author    Josh Rosen <joshrosen@databricks.com>    2015-07-10 16:44:51 -0700
committer Reynold Xin <rxin@databricks.com>        2015-07-10 16:44:51 -0700
commit    fb8807c9b04f27467b36fc9d0177ef92dd012670 (patch)
tree      556101568eb707115d6d20013a0c1dcd8090b696
parent    0772026c2fc88aa85423034006b758f6ff0cc2ed (diff)
[SPARK-7078] [SPARK-7079] Binary processing sort for Spark SQL
This patch adds a cache-friendly external sorter which operates on serialized bytes and uses this sorter to implement a new sort operator for Spark SQL and DataFrames.

### Overview of the new sorter

The new sorter design is inspired by [Alphasort](http://research.microsoft.com/pubs/68249/alphasort.doc) and implements a key-prefix optimization in order to improve the cache friendliness of the sort. In naive sort implementations, the sorting algorithm operates on an array of record pointers. To compare two records for ordering, the sorter must dereference these pointers, which likely involves random memory access, then compare the objects themselves.

![image](https://cloud.githubusercontent.com/assets/50748/8611390/3b1402ae-2675-11e5-8308-1a10bf347e6e.png)

In a key-prefix sort, the sort operates on an array which stores the record pointer alongside a prefix of the record's key. When comparing two records for ordering, the sorter first compares the stored key prefixes. If the ordering can be determined from the key prefixes (i.e. the prefixes are unequal), then the sort can skip the direct record comparison, avoiding random memory accesses and full record comparisons. For example, if we're sorting a list of strings then we can store the first 8 bytes of the UTF-8 encoded string as the key prefix and can perform unsigned byte-at-a-time comparisons to determine the ordering of strings based on their prefixes, only resorting to full comparisons for strings that share a common prefix. In cases where the sort key fits entirely in the space allotted for the key prefix (e.g. the sort key is an integer), direct record comparisons are avoided altogether.

In this patch's implementation of key-prefix sorting, the sorter's internal array stores a 64-bit key prefix and a 64-bit pointer for each record being sorted. The key prefixes are generated by the user when inserting records into the sorter and are compared using a user-defined prefix comparison function. The `PrefixComparators` object implements a set of comparators for many common types, including primitive numeric types and UTF-8 strings.

The actual sorting is implemented by `UnsafeInMemorySorter`. Most consumers will not use this directly, but instead will use `UnsafeExternalSorter`, a class which implements a sort that can spill to disk in response to memory pressure. Internally, `UnsafeExternalSorter` creates `UnsafeInMemorySorter`s to perform sorting, uses `UnsafeSorterSpillReader`/`UnsafeSorterSpillWriter` to spill and read back runs of sorted records, and uses `UnsafeSorterSpillMerger` to merge multiple sorted spills into a single sorted iterator. This external sorter integrates with Spark's existing `ShuffleMemoryManager` for controlling spilling.

Many parts of this sorter's design are based on / copied from the more specialized external sort implementation that I designed for the new UnsafeShuffleManager write path; see #5868 for more details on that patch.
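For illustration only, here is a minimal, self-contained sketch of the key-prefix comparison described above (plain Java, not part of the patch; `Entry`, `prefixOf`, and the sample strings are hypothetical). The comparator orders entries by the cheap 8-byte prefix and only touches the full record on a prefix tie:

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Comparator;

// Self-contained sketch of a key-prefix comparison (not Spark code).
public final class KeyPrefixSortSketch {

  // Stand-in for a sorted entry: an 8-byte key prefix plus a "pointer" to the full record.
  static final class Entry {
    final long keyPrefix;   // cheap-to-compare prefix of the sort key
    final String record;    // the full record (dereferenced only on prefix ties)
    Entry(long keyPrefix, String record) {
      this.keyPrefix = keyPrefix;
      this.record = record;
    }
  }

  // First 8 UTF-8 bytes of the string, zero-padded and packed big-endian into a long,
  // so that unsigned long comparison matches byte-at-a-time unsigned comparison.
  static long prefixOf(String s) {
    byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
    long prefix = 0L;
    for (int i = 0; i < 8; i++) {
      byte b = i < bytes.length ? bytes[i] : 0;
      prefix = (prefix << 8) | (b & 0xffL);
    }
    return prefix;
  }

  public static void main(String[] args) {
    Entry[] entries = {
      new Entry(prefixOf("bananas"), "bananas"),
      new Entry(prefixOf("apple pie"), "apple pie"),
      new Entry(prefixOf("apple cake"), "apple cake")
    };
    Comparator<Entry> cmp = (a, b) -> {
      // Fast path: compare prefixes without touching the records.
      int byPrefix = Long.compareUnsigned(a.keyPrefix, b.keyPrefix);
      if (byPrefix != 0) return byPrefix;
      // Slow path: the first 8 bytes tie, so compare the records themselves.
      return a.record.compareTo(b.record);
    };
    Arrays.sort(entries, cmp);
    for (Entry e : entries) System.out.println(e.record); // apple cake, apple pie, bananas
  }
}
```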
### Sorting rows in Spark SQL

For now, `UnsafeExternalSorter` is only used by Spark SQL, which uses it to implement a new sort operator, `UnsafeExternalSort`. This sort operator uses a SQL-specific class called `UnsafeExternalRowSorter` that configures an `UnsafeExternalSorter` to use prefix generators and comparators that operate on rows encoded in the UnsafeRow format that was designed for Project Tungsten.

I used some interesting unit-testing techniques to test this patch's SQL-specific components. `UnsafeExternalSortSuite` uses the SQL random data generators introduced in #7176 to test the UnsafeSort operator with all atomic types, both with and without nullability, and in both ascending and descending sort orders. `PrefixComparatorsSuite` contains a cool use of ScalaCheck + ScalaTest's `GeneratorDrivenPropertyChecks` in order to test UTF8String prefix comparison.

### Misc. additional improvements made in this patch

This patch made several miscellaneous improvements to related code in Spark SQL:

- The logic for selecting physical sort operator implementations, which was partially duplicated in both `Exchange` and `SparkStrategies`, has now been consolidated into a `getSortOperator()` helper function in `SparkStrategies`.
- The `SparkPlanTest` unit testing helper trait has been extended with new methods for comparing the output produced by two different physical plans. This makes it easy to write tests which assert that two physical operator implementations should produce the same output. I also added a method for disabling the implicit sorting of outputs prior to comparing them, a change which is necessary in order to be able to write proper SparkPlan tests for sort operators.

### Tasks deferred to followup patches

While most of this patch's features are reasonably well-tested and complete, there are a number of tasks that are intentionally being deferred to followup patches:

- Add tests which mock the ShuffleMemoryManager to check that memory pressure properly triggers spilling (there are examples of this type of test in #5868).
- Add tests to ensure that spill files are properly cleaned up after errors. I'd like to do this in the context of a patch which introduces more general metrics for ensuring proper cleanup of tasks' temporary files; see https://issues.apache.org/jira/browse/SPARK-8966 for more details.
- Metrics integration: there are some open questions regarding how to track / report spill metrics for non-shuffle operations, so I've deferred most of the IO / shuffle metrics integration for now.
- Performance profiling.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #6444 from JoshRosen/sql-external-sort and squashes the following commits:

6beb467 [Josh Rosen] Remove a bunch of overloaded methods to avoid default args. issue
2bbac9c [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
35dad9f [Josh Rosen] Make sortAnswers = false the default in SparkPlanTest
5135200 [Josh Rosen] Fix spill reading for large rows; add test
2f48777 [Josh Rosen] Add test and fix bug for sorting empty arrays
d1e28bc [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
cd05866 [Josh Rosen] Fix scalastyle
3947fc1 [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
d13ac55 [Josh Rosen] Hacky approach to copying of UnsafeRows for sort followed by limit.
845bea3 [Josh Rosen] Remove unnecessary zeroing of row conversion buffer
c56ec18 [Josh Rosen] Clean up final row copying code.
d31f180 [Josh Rosen] Re-enable NullType sorting test now that SPARK-8868 is fixed
844f4ca [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
293f109 [Josh Rosen] Add missing license header.
f99a612 [Josh Rosen] Fix bugs in string prefix comparison.
9d00afc [Josh Rosen] Clean up prefix comparators for integral types
88aff18 [Josh Rosen] NULL_PREFIX has to be negative infinity for floating point types
613e16f [Josh Rosen] Test with larger data.
1d7ffaa [Josh Rosen] Somewhat hacky fix for descending sorts
08701e7 [Josh Rosen] Fix prefix comparison of null primitives.
b86e684 [Josh Rosen] Set global = true in UnsafeExternalSortSuite.
1c7bad8 [Josh Rosen] Make sorting of answers explicit in SparkPlanTest.checkAnswer().
b81a920 [Josh Rosen] Temporarily enable only the passing sort tests
5d6109d [Josh Rosen] Fix inconsistent handling / encoding of record lengths.
87b6ed9 [Josh Rosen] Fix critical issues in test which led to false negatives.
8d7fbe7 [Josh Rosen] Fixes to multiple spilling-related bugs.
82e21c1 [Josh Rosen] Force spilling in UnsafeExternalSortSuite.
88b72db [Josh Rosen] Test ascending and descending sort orders.
f27be09 [Josh Rosen] Fix tests by binding attributes.
0a79d39 [Josh Rosen] Revert "Undo part of a SparkPlanTest change in #7162 that broke my test."
7c3c864 [Josh Rosen] Undo part of a SparkPlanTest change in #7162 that broke my test.
9969c14 [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
5822e6f [Josh Rosen] Fix test compilation issue
939f824 [Josh Rosen] Remove code gen experiment.
0dfe919 [Josh Rosen] Implement prefix sort for strings (albeit inefficiently).
66a813e [Josh Rosen] Prefix comparators for float and double
b310c88 [Josh Rosen] Integrate prefix comparators for Int and Long (others coming soon)
95058d9 [Josh Rosen] Add missing SortPrefixUtils file
4c37ba6 [Josh Rosen] Add tests for sorting on all primitive types.
6890863 [Josh Rosen] Fix memory leak on empty inputs.
d246e29 [Josh Rosen] Fix consideration of column types when choosing sort implementation.
6b156fb [Josh Rosen] Some WIP work on prefix comparison.
7f875f9 [Josh Rosen] Commit failing test demonstrating bug in handling objects in spills
41b8881 [Josh Rosen] Get UnsafeInMemorySorterSuite to pass (WIP)
90c2b6a [Josh Rosen] Update test name
6d6a1e6 [Josh Rosen] Centralize logic for picking sort operator implementations
9869ec2 [Josh Rosen] Clean up Exchange code a bit
82bb0ec [Josh Rosen] Fix IntelliJ complaint due to negated if condition
1db845a [Josh Rosen] Many more changes to harmonize with shuffle sorter
ebf9eea [Josh Rosen] Harmonization with shuffle's unsafe sorter
206bfa2 [Josh Rosen] Add some missing newlines at the ends of files
26c8931 [Josh Rosen] Back out some Hive changes that aren't needed anymore
62f0bb8 [Josh Rosen] Update to reflect SparkPlanTest changes
21d7d93 [Josh Rosen] Back out of BlockObjectWriter change
7eafecf [Josh Rosen] Port test to SparkPlanTest
d468a88 [Josh Rosen] Update for InternalRow refactoring
269cf86 [Josh Rosen] Back out SMJ operator change; isolate changes to selection of sort op.
1b841ca [Josh Rosen] WIP towards copying
b420a71 [Josh Rosen] Move most of the existing SMJ code into Java.
dfdb93f [Josh Rosen] SparkFunSuite change
73cc761 [Josh Rosen] Fix whitespace
9cc98f5 [Josh Rosen] Move more code to Java; fix bugs in UnsafeRowConverter length type.
c8792de [Josh Rosen] Remove some debug logging
dda6752 [Josh Rosen] Commit some missing code from an old git stash.
58f36d0 [Josh Rosen] Merge in a sketch of a unit test for the new sorter (now failing).
2bd8c9a [Josh Rosen] Import my original tests and get them to pass.
d5d3106 [Josh Rosen] WIP towards external sorter for Spark SQL.
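To make the spill-and-merge flow from the overview concrete, here is a hedged, self-contained sketch (plain Java, not part of the patch; `Run`, `merge`, and the sample values are hypothetical) of merging already-sorted runs with a priority queue keyed on each run's current head. This is the same shape of merge that `UnsafeSorterSpillMerger` performs over spill files and the in-memory run, except that the real merger compares key prefixes and falls back to a record comparator on ties:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

// Sketch of a k-way merge over sorted runs, in the spirit of UnsafeSorterSpillMerger.
public final class KWayMergeSketch {

  // A "run" is a sorted array plus a cursor, standing in for one sorted spill file.
  static final class Run {
    final long[] values;
    int pos = 0;
    Run(long[] values) { this.values = values; }
    boolean hasNext() { return pos < values.length; }
    long head() { return values[pos]; }
    void advance() { pos++; }
  }

  static List<Long> merge(List<Run> runs) {
    // Order runs by the value currently at their head.
    PriorityQueue<Run> queue = new PriorityQueue<>((a, b) -> Long.compare(a.head(), b.head()));
    for (Run r : runs) {
      if (r.hasNext()) queue.add(r);
    }
    List<Long> out = new ArrayList<>();
    while (!queue.isEmpty()) {
      Run r = queue.remove();        // run with the smallest head
      out.add(r.head());
      r.advance();
      if (r.hasNext()) queue.add(r); // re-insert with its new head
    }
    return out;
  }

  public static void main(String[] args) {
    List<Run> runs = Arrays.asList(
        new Run(new long[] {1, 4, 9}),
        new Run(new long[] {2, 3, 10}),
        new Run(new long[] {5}));
    System.out.println(merge(runs)); // [1, 2, 3, 4, 5, 9, 10]
  }
}
```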
-rw-r--r-- core/pom.xml                                                                                   |  20
-rw-r--r-- core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java) | 9
-rw-r--r-- core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java            |   1
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java          |  29
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java         | 109
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java          |  37
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java |  31
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java      | 282
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java      | 189
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java      |  80
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java      |  35
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java   |  91
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java   |  98
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java   | 146
-rw-r--r-- core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 202
-rw-r--r-- core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java | 139
-rw-r--r-- core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala  |  50
-rw-r--r-- sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java            |   7
-rw-r--r-- sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java         | 216
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala                |  27
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                          |  13
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala                   |  97
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                   |  22
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                    |  73
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala                         |  11
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala                     | 253
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala           | 104
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala              |  21
28 files changed, 2254 insertions, 138 deletions
diff --git a/core/pom.xml b/core/pom.xml
index aee0d92620..558cc3fb9f 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -343,28 +343,28 @@
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.mockito</groupId>
- <artifactId>mockito-core</artifactId>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-library</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>junit</groupId>
- <artifactId>junit</artifactId>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.hamcrest</groupId>
- <artifactId>hamcrest-core</artifactId>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.hamcrest</groupId>
- <artifactId>hamcrest-library</artifactId>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
index 3f746b886b..0399abc63c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
+++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.serializer;
import java.io.IOException;
import java.io.InputStream;
@@ -24,9 +24,7 @@ import java.nio.ByteBuffer;
import scala.reflect.ClassTag;
-import org.apache.spark.serializer.DeserializationStream;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.PlatformDependent;
/**
@@ -35,7 +33,8 @@ import org.apache.spark.unsafe.PlatformDependent;
* `write()` OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
* around this, we pass a dummy no-op serializer.
*/
-final class DummySerializerInstance extends SerializerInstance {
+@Private
+public final class DummySerializerInstance extends SerializerInstance {
public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 9e9ed94b78..5628957320 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -30,6 +30,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
new file mode 100644
index 0000000000..45b78829e4
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific
+ * comparisons, such as lexicographic comparison for strings.
+ */
+@Private
+public abstract class PrefixComparator {
+ public abstract int compare(long prefix1, long prefix2);
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
new file mode 100644
index 0000000000..438742565c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.google.common.base.Charsets;
+import com.google.common.primitives.Longs;
+import com.google.common.primitives.UnsignedBytes;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.unsafe.types.UTF8String;
+
+@Private
+public class PrefixComparators {
+ private PrefixComparators() {}
+
+ public static final StringPrefixComparator STRING = new StringPrefixComparator();
+ public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator();
+ public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
+ public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
+
+ public static final class StringPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ // TODO: this can be done more efficiently
+ byte[] a = Longs.toByteArray(aPrefix);
+ byte[] b = Longs.toByteArray(bPrefix);
+ for (int i = 0; i < 8; i++) {
+ int c = UnsignedBytes.compare(a[i], b[i]);
+ if (c != 0) return c;
+ }
+ return 0;
+ }
+
+ public long computePrefix(byte[] bytes) {
+ if (bytes == null) {
+ return 0L;
+ } else {
+ byte[] padded = new byte[8];
+ System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8));
+ return Longs.fromByteArray(padded);
+ }
+ }
+
+ public long computePrefix(String value) {
+ return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8));
+ }
+
+ public long computePrefix(UTF8String value) {
+ return value == null ? 0L : computePrefix(value.getBytes());
+ }
+ }
+
+ /**
+ * Prefix comparator for all integral types (boolean, byte, short, int, long).
+ */
+ public static final class IntegralPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long a, long b) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public final long NULL_PREFIX = Long.MIN_VALUE;
+ }
+
+ public static final class FloatPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ float a = Float.intBitsToFloat((int) aPrefix);
+ float b = Float.intBitsToFloat((int) bPrefix);
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public long computePrefix(float value) {
+ return Float.floatToIntBits(value) & 0xffffffffL;
+ }
+
+ public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY);
+ }
+
+ public static final class DoublePrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ double a = Double.longBitsToDouble(aPrefix);
+ double b = Double.longBitsToDouble(bPrefix);
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public long computePrefix(double value) {
+ return Double.doubleToLongBits(value);
+ }
+
+ public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY);
+ }
+}
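For illustration, a hedged usage sketch of the comparators defined in the file above (not part of the patch; it assumes the `PrefixComparators` class added here is on the classpath and the values are made up):

```java
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators;

// Illustrative use of the prefix comparators defined above.
public final class PrefixComparatorsExample {
  public static void main(String[] args) {
    // String prefixes: first 8 UTF-8 bytes, compared unsigned byte-at-a-time.
    long apple = PrefixComparators.STRING.computePrefix("apple");
    long banana = PrefixComparators.STRING.computePrefix("banana");
    System.out.println(PrefixComparators.STRING.compare(apple, banana) < 0); // true

    // Floats are encoded with floatToIntBits and decoded again inside compare().
    long one = PrefixComparators.FLOAT.computePrefix(1.0f);
    long two = PrefixComparators.FLOAT.computePrefix(2.0f);
    System.out.println(PrefixComparators.FLOAT.compare(one, two) < 0); // true

    // Integral types use the long value itself as the prefix.
    System.out.println(PrefixComparators.INTEGRAL.compare(-5L, 3L) < 0); // true
  }
}
```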
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
new file mode 100644
index 0000000000..09e4258792
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+/**
+ * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte
+ * prefix, this may simply return 0.
+ */
+public abstract class RecordComparator {
+
+ /**
+ * Compare two records for order.
+ *
+ * @return a negative integer, zero, or a positive integer as the first record is less than,
+ * equal to, or greater than the second.
+ */
+ public abstract int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset);
+}
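As a hedged example of implementing the abstract class above (not part of the patch; `IntRecordComparator` is a hypothetical name, along the lines of what this patch's test suites do), a comparator for records whose payload is a single 4-byte int might look like this:

```java
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;

// Hypothetical comparator for records whose payload is a single 4-byte int,
// read directly from the record's base object and offset.
public final class IntRecordComparator extends RecordComparator {
  @Override
  public int compare(
      Object leftBaseObject,
      long leftBaseOffset,
      Object rightBaseObject,
      long rightBaseOffset) {
    int left = PlatformDependent.UNSAFE.getInt(leftBaseObject, leftBaseOffset);
    int right = PlatformDependent.UNSAFE.getInt(rightBaseObject, rightBaseOffset);
    return Integer.compare(left, right);
  }
}
```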
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
new file mode 100644
index 0000000000..0c4ebde407
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+final class RecordPointerAndKeyPrefix {
+ /**
+ * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
+ * description of how these addresses are encoded.
+ */
+ public long recordPointer;
+
+ /**
+ * A key prefix, for use in comparisons.
+ */
+ public long keyPrefix;
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
new file mode 100644
index 0000000000..4d6731ee60
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * External sorter based on {@link UnsafeInMemorySorter}.
+ */
+public final class UnsafeExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
+
+ private static final int PAGE_SIZE = 1 << 27; // 128 megabytes
+ @VisibleForTesting
+ static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+ private final PrefixComparator prefixComparator;
+ private final RecordComparator recordComparator;
+ private final int initialSize;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private ShuffleWriteMetrics writeMetrics;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ // These variables are reset after spilling:
+ private UnsafeInMemorySorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+ private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
+ public UnsafeExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ SparkConf conf) throws IOException {
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.initialSize = initialSize;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ initializeForWriting();
+ }
+
+ // TODO: metrics tracking + integration with shuffle write metrics
+ // need to connect the write metrics to task metrics so we count the spill IO somewhere.
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ this.writeMetrics = new ShuffleWriteMetrics();
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.sorter =
+ new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize);
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ public void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spillWriters.size(),
+ spillWriters.size() > 1 ? " times" : " time");
+
+ final UnsafeSorterSpillWriter spillWriter =
+ new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+ sorter.numRecords());
+ spillWriters.add(spillWriter);
+ final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final Object baseObject = sortedRecords.getBaseObject();
+ final long baseOffset = sortedRecords.getBaseOffset();
+ final int recordLength = sortedRecords.getRecordLength();
+ spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+ }
+ spillWriter.close();
+ final long sorterMemoryUsage = sorter.getMemoryUsage();
+ sorter = null;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+ }
+
+ public long freeMemory() {
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ memoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+
+ * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+ assert (requiredSpace > 0);
+ return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ */
+ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ // TODO: merge these steps to first calculate total memory requirements for this insert,
+ // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
+ // data page.
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ sorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+
+ if (requiredSpace > freeSpaceInCurrentPage) {
+ logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
+ freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > PAGE_SIZE) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ PAGE_SIZE + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquired < PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+ }
+ }
+ currentPage = memoryManager.allocatePage(PAGE_SIZE);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = PAGE_SIZE;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ long prefix) throws IOException {
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+ currentPagePosition += 4;
+ PlatformDependent.copyMemory(
+ recordBaseObject,
+ recordBaseOffset,
+ dataPageBaseObject,
+ currentPagePosition,
+ lengthInBytes);
+ currentPagePosition += lengthInBytes;
+
+ sorter.insertRecord(recordAddress, prefix);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+ int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+ if (spillWriters.isEmpty()) {
+ return inMemoryIterator;
+ } else {
+ final UnsafeSorterSpillMerger spillMerger =
+ new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
+ for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+ spillMerger.addSpill(spillWriter.getReader(blockManager));
+ }
+ spillWriters.clear();
+ if (inMemoryIterator.hasNext()) {
+ spillMerger.addSpill(inMemoryIterator);
+ }
+ return spillMerger.getSortedIterator();
+ }
+ }
+}
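The in-page record layout used by `insertRecord` above is a 4-byte length header followed by the record bytes, appended at the current page position. A hedged sketch of that layout (not part of the patch; a plain on-heap byte[] stands in for the off-heap `MemoryBlock` page):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

// Sketch of the data-page record layout: [record length (4-byte int)][record bytes].
public final class PageLayoutSketch {
  public static void main(String[] args) {
    byte[] page = new byte[1 << 12];       // stand-in for a 4 KB page
    int pagePosition = 0;

    byte[] record = "hello".getBytes(StandardCharsets.UTF_8);

    // Write the 4-byte length header, then the record bytes (mirrors insertRecord above).
    ByteBuffer buf = ByteBuffer.wrap(page).order(ByteOrder.nativeOrder());
    buf.putInt(pagePosition, record.length);
    pagePosition += 4;
    System.arraycopy(record, 0, page, pagePosition, record.length);
    pagePosition += record.length;

    // A reader (like the in-memory sorter's iterator) first reads the length header,
    // then knows how many record bytes follow it.
    int storedLength = buf.getInt(0);
    String roundTripped = new String(page, 4, storedLength, StandardCharsets.UTF_8);
    System.out.println(storedLength + " -> " + roundTripped); // 5 -> hello
  }
}
```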
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
new file mode 100644
index 0000000000..fc34ad9cff
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Comparator;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
+ * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm
+ * compares records, it will first compare the stored key prefixes; if the prefixes are not equal,
+ * then we do not need to traverse the record pointers to compare the actual records. Avoiding these
+ * random memory accesses improves cache hit rates.
+ */
+public final class UnsafeInMemorySorter {
+
+ private static final class SortComparator implements Comparator<RecordPointerAndKeyPrefix> {
+
+ private final RecordComparator recordComparator;
+ private final PrefixComparator prefixComparator;
+ private final TaskMemoryManager memoryManager;
+
+ SortComparator(
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ TaskMemoryManager memoryManager) {
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.memoryManager = memoryManager;
+ }
+
+ @Override
+ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
+ final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix);
+ if (prefixComparisonResult == 0) {
+ final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
+ final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length
+ final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
+ final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length
+ return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ }
+
+ private final TaskMemoryManager memoryManager;
+ private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+ private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
+
+ /**
+ * Within this buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the sort buffer where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeInMemorySorter(
+ final TaskMemoryManager memoryManager,
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize * 2];
+ this.memoryManager = memoryManager;
+ this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
+ this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ }
+
+ /**
+ * @return the number of records that have been inserted into this sorter.
+ */
+ public int numRecords() {
+ return pointerArrayInsertPosition / 2;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 2 < pointerArray.length;
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ /**
+ * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+ * stored as a 4-byte integer, followed by the record's bytes.
+ *
+ * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
+ * @param keyPrefix a user-defined key prefix
+ */
+ public void insertRecord(long recordPointer, long keyPrefix) {
+ if (!hasSpaceForAnotherRecord()) {
+ expandPointerArray();
+ }
+ pointerArray[pointerArrayInsertPosition] = recordPointer;
+ pointerArrayInsertPosition++;
+ pointerArray[pointerArrayInsertPosition] = keyPrefix;
+ pointerArrayInsertPosition++;
+ }
+
+ private static final class SortedIterator extends UnsafeSorterIterator {
+
+ private final TaskMemoryManager memoryManager;
+ private final int sortBufferInsertPosition;
+ private final long[] sortBuffer;
+ private int position = 0;
+ private Object baseObject;
+ private long baseOffset;
+ private long keyPrefix;
+ private int recordLength;
+
+ SortedIterator(
+ TaskMemoryManager memoryManager,
+ int sortBufferInsertPosition,
+ long[] sortBuffer) {
+ this.memoryManager = memoryManager;
+ this.sortBufferInsertPosition = sortBufferInsertPosition;
+ this.sortBuffer = sortBuffer;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position < sortBufferInsertPosition;
+ }
+
+ @Override
+ public void loadNext() {
+ // This pointer points to a 4-byte record length, followed by the record's bytes
+ final long recordPointer = sortBuffer[position];
+ baseObject = memoryManager.getPage(recordPointer);
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
+ recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
+ keyPrefix = sortBuffer[position + 1];
+ position += 2;
+ }
+
+ @Override
+ public Object getBaseObject() { return baseObject; }
+
+ @Override
+ public long getBaseOffset() { return baseOffset; }
+
+ @Override
+ public int getRecordLength() { return recordLength; }
+
+ @Override
+ public long getKeyPrefix() { return keyPrefix; }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order. For efficiency, all calls to
+ * {@code next()} will return the same mutable object.
+ */
+ public UnsafeSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
+ return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+ }
+}
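A hedged, self-contained sketch of the interleaved long[] layout described in the class above (not part of the patch; the pointer values and the insertion sort are illustrative stand-ins): slot 2*i holds the record pointer and slot 2*i + 1 holds the 8-byte key prefix. Here the pairs are sorted by prefix only, whereas the real sorter uses Spark's `Sorter` with `UnsafeSortDataFormat` and falls back to a `RecordComparator` when two prefixes are equal:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of the (pointer, prefix) pair layout used by UnsafeInMemorySorter.
public final class PointerArraySketch {
  public static void main(String[] args) {
    long[] array = new long[8];
    int insertPosition = 0;

    // insertRecord(recordPointer, keyPrefix) appends a (pointer, prefix) pair.
    long[][] inserts = { {1001L, 30L}, {1002L, 10L}, {1003L, 20L} };
    for (long[] rec : inserts) {
      array[insertPosition++] = rec[0]; // record pointer
      array[insertPosition++] = rec[1]; // key prefix
    }

    // Sort the pairs by prefix with a simple insertion sort over pair indices.
    int numRecords = insertPosition / 2;
    for (int i = 1; i < numRecords; i++) {
      for (int j = i; j > 0 && array[2 * j + 1] < array[2 * (j - 1) + 1]; j--) {
        swapPair(array, j, j - 1);
      }
    }

    List<String> sorted = new ArrayList<>();
    for (int i = 0; i < numRecords; i++) {
      sorted.add("pointer=" + array[2 * i] + ", prefix=" + array[2 * i + 1]);
    }
    System.out.println(sorted); // prefixes come out in order 10, 20, 30
  }

  // Swap the (pointer, prefix) pairs at indices a and b, mirroring UnsafeSortDataFormat.swap.
  static void swapPair(long[] data, int a, int b) {
    long p = data[2 * a];
    long k = data[2 * a + 1];
    data[2 * a] = data[2 * b];
    data[2 * a + 1] = data[2 * b + 1];
    data[2 * b] = p;
    data[2 * b + 1] = k;
  }
}
```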
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
new file mode 100644
index 0000000000..d09c728a7a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+/**
+ * Supports sorting an array of (record pointer, key prefix) pairs.
+ * Used in {@link UnsafeInMemorySorter}.
+ * <p>
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+
+ public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+
+ private UnsafeSortDataFormat() { }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix newKey() {
+ return new RecordPointerAndKeyPrefix();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
+ reuse.recordPointer = data[pos * 2];
+ reuse.keyPrefix = data[pos * 2 + 1];
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ long tempPointer = data[pos0 * 2];
+ long tempKeyPrefix = data[pos0 * 2 + 1];
+ data[pos0 * 2] = data[pos1 * 2];
+ data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
+ data[pos1 * 2] = tempPointer;
+ data[pos1 * 2 + 1] = tempKeyPrefix;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos * 2] = src[srcPos * 2];
+ dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
+ return new long[length * 2];
+ }
+
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
new file mode 100644
index 0000000000..16ac2e8d82
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+
+public abstract class UnsafeSorterIterator {
+
+ public abstract boolean hasNext();
+
+ public abstract void loadNext() throws IOException;
+
+ public abstract Object getBaseObject();
+
+ public abstract long getBaseOffset();
+
+ public abstract int getRecordLength();
+
+ public abstract long getKeyPrefix();
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
new file mode 100644
index 0000000000..8272c2a5be
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.PriorityQueue;
+
+final class UnsafeSorterSpillMerger {
+
+ private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
+
+ public UnsafeSorterSpillMerger(
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ final int numSpills) {
+ final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
+
+ @Override
+ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
+ final int prefixComparisonResult =
+ prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
+ if (prefixComparisonResult == 0) {
+ return recordComparator.compare(
+ left.getBaseObject(), left.getBaseOffset(),
+ right.getBaseObject(), right.getBaseOffset());
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ };
+ priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
+ }
+
+ public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ }
+ priorityQueue.add(spillReader);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ return new UnsafeSorterIterator() {
+
+ private UnsafeSorterIterator spillReader;
+
+ @Override
+ public boolean hasNext() {
+ return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ if (spillReader != null) {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ priorityQueue.add(spillReader);
+ }
+ }
+ spillReader = priorityQueue.remove();
+ }
+
+ @Override
+ public Object getBaseObject() { return spillReader.getBaseObject(); }
+
+ @Override
+ public long getBaseOffset() { return spillReader.getBaseOffset(); }
+
+ @Override
+ public int getRecordLength() { return spillReader.getRecordLength(); }
+
+ @Override
+ public long getKeyPrefix() { return spillReader.getKeyPrefix(); }
+ };
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
new file mode 100644
index 0000000000..29e9e0f30f
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.*;
+
+import com.google.common.io.ByteStreams;
+
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
+ * of the file format).
+ */
+final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+
+ private InputStream in;
+ private DataInputStream din;
+
+ // Variables that change with every record read:
+ private int recordLength;
+ private long keyPrefix;
+ private int numRecordsRemaining;
+
+ private byte[] arr = new byte[1024 * 1024];
+ private Object baseObject = arr;
+ private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+
+ public UnsafeSorterSpillReader(
+ BlockManager blockManager,
+ File file,
+ BlockId blockId) throws IOException {
+ assert (file.length() > 0);
+ final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+ this.in = blockManager.wrapForCompression(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecordsRemaining = din.readInt();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ recordLength = din.readInt();
+ keyPrefix = din.readLong();
+ if (recordLength > arr.length) {
+ arr = new byte[recordLength];
+ baseObject = arr;
+ }
+ ByteStreams.readFully(in, arr, 0, recordLength);
+ numRecordsRemaining--;
+ if (numRecordsRemaining == 0) {
+ in.close();
+ in = null;
+ din = null;
+ }
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return baseOffset;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+}
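The on-disk framing that the reader above consumes, and that `UnsafeSorterSpillWriter` produces, is `[# of records (int)]` followed by, per record, `[len (int)][prefix (long)][data (bytes)]`. A hedged round-trip sketch of that framing with plain Java streams (not part of the patch; no compression or Spark block manager involved):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

// Round-trip sketch of the spill-file framing used by the spill writer and reader.
public final class SpillFormatSketch {
  public static void main(String[] args) throws IOException {
    byte[][] records = {
        "apple".getBytes(StandardCharsets.UTF_8),
        "banana".getBytes(StandardCharsets.UTF_8)
    };
    long[] prefixes = {97L, 98L}; // illustrative key prefixes

    // Write phase (what the spill writer does record by record).
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    out.writeInt(records.length);               // number of records in this spill
    for (int i = 0; i < records.length; i++) {
      out.writeInt(records[i].length);          // record length
      out.writeLong(prefixes[i]);               // key prefix
      out.write(records[i]);                    // record payload
    }
    out.close();

    // Read phase (what the spill reader's loadNext() loop does).
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
    int numRecords = in.readInt();
    for (int i = 0; i < numRecords; i++) {
      int len = in.readInt();
      long prefix = in.readLong();
      byte[] data = new byte[len];
      in.readFully(data);
      System.out.println(prefix + " -> " + new String(data, StandardCharsets.UTF_8));
    }
  }
}
```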
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
new file mode 100644
index 0000000000..b8d6665980
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Tuple2;
+
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Spills a list of sorted records to disk. Spill files have the following format:
+ *
+ * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
+ */
+final class UnsafeSorterSpillWriter {
+
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array.
+ private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ private final File file;
+ private final BlockId blockId;
+ private final int numRecordsToWrite;
+ private BlockObjectWriter writer;
+ private int numRecordsSpilled = 0;
+
+ public UnsafeSorterSpillWriter(
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ int numRecordsToWrite) throws IOException {
+ final Tuple2<TempLocalBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempLocalBlock();
+ this.file = spilledFileInfo._2();
+ this.blockId = spilledFileInfo._1();
+ this.numRecordsToWrite = numRecordsToWrite;
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ writer = blockManager.getDiskWriter(
+ blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+ // Write the number of records
+ writeIntToBuffer(numRecordsToWrite, 0);
+ writer.write(writeBuffer, 0, 4);
+ }
+
+ // Based on DataOutputStream.writeLong.
+ private void writeLongToBuffer(long v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 56);
+ writeBuffer[offset + 1] = (byte)(v >>> 48);
+ writeBuffer[offset + 2] = (byte)(v >>> 40);
+ writeBuffer[offset + 3] = (byte)(v >>> 32);
+ writeBuffer[offset + 4] = (byte)(v >>> 24);
+ writeBuffer[offset + 5] = (byte)(v >>> 16);
+ writeBuffer[offset + 6] = (byte)(v >>> 8);
+ writeBuffer[offset + 7] = (byte)(v >>> 0);
+ }
+
+ // Based on DataOutputStream.writeInt.
+ private void writeIntToBuffer(int v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 24);
+ writeBuffer[offset + 1] = (byte)(v >>> 16);
+ writeBuffer[offset + 2] = (byte)(v >>> 8);
+ writeBuffer[offset + 3] = (byte)(v >>> 0);
+ }
+
+ /**
+ * Write a record to a spill file.
+ *
+ * @param baseObject the base object / memory page containing the record
+ * @param baseOffset the base offset which points directly to the record data.
+ * @param recordLength the length of the record.
+ * @param keyPrefix a sort key prefix
+ */
+ public void write(
+ Object baseObject,
+ long baseOffset,
+ int recordLength,
+ long keyPrefix) throws IOException {
+ if (numRecordsSpilled == numRecordsToWrite) {
+ throw new IllegalStateException(
+ "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+ } else {
+ numRecordsSpilled++;
+ }
+ writeIntToBuffer(recordLength, 0);
+ writeLongToBuffer(keyPrefix, 4);
+ int dataRemaining = recordLength;
+ int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // 4-byte len + 8-byte prefix already buffered
+ long recordReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+ toTransfer);
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE;
+ }
+ if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) {
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer));
+ }
+ writer.recordWritten();
+ }
+
+ public void close() throws IOException {
+ writer.commitAndClose();
+ writer = null;
+ writeBuffer = null;
+ }
+
+ public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
+ return new UnsafeSorterSpillReader(blockManager, file, blockId);
+ }
+}
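
As a rough usage sketch (same package, illustrative names only), this is how three single-int records whose key prefix is the value itself could be spilled and read back; a BlockManager is assumed to be available, as it is via mocks in the test suite further below:

    final ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
    final UnsafeSorterSpillWriter spillWriter =
      new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, 3);
    for (int value : new int[] { 5, 1, 3 }) {
      final int[] holder = new int[] { value };
      // Record data is 4 bytes (one int); the value doubles as the sort key prefix.
      spillWriter.write(holder, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
    }
    spillWriter.close();
    // Ignoring compression, the file now holds [3][len=4][prefix=5][4 data bytes] followed by the
    // corresponding frames for 1 and 3, matching the format described in the class javadoc.
    final UnsafeSorterSpillReader reader = spillWriter.getReader(blockManager);
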
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
new file mode 100644
index 0000000000..ea8755e21e
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.UUID;
+
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeExternalSorterSuite {
+
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ // Use integer comparison for comparing prefixes (which are the inserted int values, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+
+ File tempDir;
+
+ private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ return stream;
+ }
+ }
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ tempDir = new File(Utils.createTempDir$default$1());
+ taskContext = mock(TaskContext.class);
+ when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+ @Override
+ public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+ TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (BlockId) args[0],
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+ .then(returnsSecondArg());
+ }
+
+ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
+ final int[] arr = new int[] { value };
+ sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
+ }
+
+ @Test
+ public void testSortingOnlyByPrefix() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ insertNumber(sorter, 5);
+ insertNumber(sorter, 1);
+ insertNumber(sorter, 3);
+ sorter.spill();
+ insertNumber(sorter, 4);
+ sorter.spill();
+ insertNumber(sorter, 2);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(i, iter.getKeyPrefix());
+ assertEquals(4, iter.getRecordLength());
+ // TODO: read rest of value.
+ }
+
+ // TODO: test for cleanup:
+ // assert(tempDir.isEmpty)
+ }
+
+ @Test
+ public void testSortingEmptyArrays() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(0, iter.getKeyPrefix());
+ assertEquals(0, iter.getRecordLength());
+ }
+ }
+
+}
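
The "read rest of value" TODO in testSortingOnlyByPrefix could plausibly be filled in along these lines (a sketch only, using the loop's `iter` and `i`): each record is the 4-byte int written by insertNumber, so it can be read straight off the page the iterator points at and compared with the prefix:

    final int value = PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset());
    assertEquals(i, value);
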
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
new file mode 100644
index 0000000000..9095009305
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Arrays;
+
+import org.junit.Test;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.*;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
+ final byte[] strBytes = new byte[length];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, length);
+ return new String(strBytes);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+ mock(RecordComparator.class),
+ mock(PrefixComparator.class),
+ 100);
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testSortingOnlyByIntegerPrefix() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ // Write the records into the data page:
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+ final byte[] strBytes = str.getBytes("utf-8");
+ PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ PlatformDependent.copyMemory(
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ position,
+ strBytes.length);
+ position += strBytes.length;
+ }
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+ // Compute key prefixes based on the records' partition ids
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
+ prefixComparator, dataToSort.length);
+ // Given a page of records, insert those records into the sorter one-by-one:
+ position = dataPage.getBaseOffset();
+ for (int i = 0; i < dataToSort.length; i++) {
+ // position now points to the start of a record (which holds its length).
+ final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position);
+ final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
+ final int partitionId = hashPartitioner.getPartition(str);
+ sorter.insertRecord(address, partitionId);
+ position += 4 + recordLength;
+ }
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ int iterLength = 0;
+ long prevPrefix = -1;
+ Arrays.sort(dataToSort);
+ while (iter.hasNext()) {
+ iter.loadNext();
+ final String str =
+ getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength());
+ final long keyPrefix = iter.getKeyPrefix();
+ assertThat(str, isIn(Arrays.asList(dataToSort)));
+ assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
+ prevPrefix = keyPrefix;
+ iterLength++;
+ }
+ assertEquals(dataToSort.length, iterLength);
+ }
+}
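
For a smaller-scale illustration of the same page layout ([record length][record bytes]) with numeric keys, here is a hypothetical in-memory-only sketch, reusing dummy RecordComparator and integer PrefixComparator instances like those defined in the suites above (imports as in this file):

    final TaskMemoryManager mm =
      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
    final MemoryBlock page = mm.allocatePage(4096);
    final Object base = page.getBaseObject();
    final int[] values = { 42, 7, 19 };
    long pos = page.getBaseOffset();
    for (int v : values) {
      PlatformDependent.UNSAFE.putInt(base, pos, 4);     // record length
      PlatformDependent.UNSAFE.putInt(base, pos + 4, v); // record payload
      pos += 8;
    }
    final UnsafeInMemorySorter sorter =
      new UnsafeInMemorySorter(mm, recordComparator, prefixComparator, values.length);
    pos = page.getBaseOffset();
    for (int v : values) {
      // Use the value itself as the prefix; the record comparator is never consulted.
      sorter.insertRecord(mm.encodePageNumberAndOffset(page, pos), v);
      pos += 8;
    }
    final UnsafeSorterIterator it = sorter.getSortedIterator(); // yields 7, 19, 42 in prefix order
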
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
new file mode 100644
index 0000000000..dd505dfa7d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort
+
+import org.scalatest.prop.PropertyChecks
+
+import org.apache.spark.SparkFunSuite
+
+class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
+
+ test("String prefix comparator") {
+
+ def testPrefixComparison(s1: String, s2: String): Unit = {
+ val s1Prefix = PrefixComparators.STRING.computePrefix(s1)
+ val s2Prefix = PrefixComparators.STRING.computePrefix(s2)
+ val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
+ assert(
+ (prefixComparisonResult == 0) ||
+ (prefixComparisonResult < 0 && s1 < s2) ||
+ (prefixComparisonResult > 0 && s1 > s2))
+ }
+
+ // scalastyle:off
+ val regressionTests = Table(
+ ("s1", "s2"),
+ ("abc", "世界"),
+ ("你好", "世界"),
+ ("你好123", "你好122")
+ )
+ // scalastyle:on
+
+ forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ }
+}
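
The property this suite checks is that a prefix comparison may be inconclusive (return 0) but must never contradict the full string ordering. A hypothetical Java illustration, on the assumption that the string prefix is derived from the string's leading UTF-8 bytes:

    long p1 = PrefixComparators.STRING.computePrefix("abcdefgh_x");
    long p2 = PrefixComparators.STRING.computePrefix("abcdefgh_y");
    // The first 8 bytes are identical, so the prefixes tie and the sorter must fall back to a
    // full record comparison to order these two strings.
    int tie = PrefixComparators.STRING.compare(p1, p2);      // 0
    long p3 = PrefixComparators.STRING.computePrefix("apple");
    long p4 = PrefixComparators.STRING.computePrefix("banana");
    // The leading bytes differ, so the prefix comparison alone decides the order.
    int decided = PrefixComparators.STRING.compare(p3, p4);  // negative
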
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index edb7202245..4b99030d10 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -61,9 +61,10 @@ public final class UnsafeRow extends MutableRow {
/** A pool to hold non-primitive objects */
private ObjectPool pool;
- Object getBaseObject() { return baseObject; }
- long getBaseOffset() { return baseOffset; }
- ObjectPool getPool() { return pool; }
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
+ public int getSizeInBytes() { return sizeInBytes; }
+ public ObjectPool getPool() { return pool; }
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
new file mode 100644
index 0000000000..b94601cf6d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import scala.collection.Iterator;
+import scala.math.Ordering;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.sql.AbstractScalaRowIterator;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
+
+final class UnsafeExternalRowSorter {
+
+ /**
+ * If positive, forces records to be spilled to disk at the given frequency (measured in number
+ * of records). This is only intended to be used in tests.
+ */
+ private int testSpillFrequency = 0;
+
+ private long numRowsInserted = 0;
+
+ private final StructType schema;
+ private final UnsafeRowConverter rowConverter;
+ private final PrefixComputer prefixComputer;
+ private final UnsafeExternalSorter sorter;
+ private byte[] rowConversionBuffer = new byte[1024 * 8];
+
+ public static abstract class PrefixComputer {
+ abstract long computePrefix(InternalRow row);
+ }
+
+ public UnsafeExternalRowSorter(
+ StructType schema,
+ Ordering<InternalRow> ordering,
+ PrefixComparator prefixComparator,
+ PrefixComputer prefixComputer) throws IOException {
+ this.schema = schema;
+ this.rowConverter = new UnsafeRowConverter(schema);
+ this.prefixComputer = prefixComputer;
+ final SparkEnv sparkEnv = SparkEnv.get();
+ final TaskContext taskContext = TaskContext.get();
+ sorter = new UnsafeExternalSorter(
+ taskContext.taskMemoryManager(),
+ sparkEnv.shuffleMemoryManager(),
+ sparkEnv.blockManager(),
+ taskContext,
+ new RowComparator(ordering, schema.length(), null),
+ prefixComparator,
+ 4096,
+ sparkEnv.conf()
+ );
+ }
+
+ /**
+ * Forces spills to occur every `frequency` records. Only for use in tests.
+ */
+ @VisibleForTesting
+ void setTestSpillFrequency(int frequency) {
+ assert frequency > 0 : "Frequency must be positive";
+ testSpillFrequency = frequency;
+ }
+
+ @VisibleForTesting
+ void insertRow(InternalRow row) throws IOException {
+ final int sizeRequirement = rowConverter.getSizeRequirement(row);
+ if (sizeRequirement > rowConversionBuffer.length) {
+ rowConversionBuffer = new byte[sizeRequirement];
+ }
+ final int bytesWritten = rowConverter.writeRow(
+ row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null);
+ assert (bytesWritten == sizeRequirement);
+ final long prefix = prefixComputer.computePrefix(row);
+ sorter.insertRecord(
+ rowConversionBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ sizeRequirement,
+ prefix
+ );
+ numRowsInserted++;
+ if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
+ spill();
+ }
+ }
+
+ @VisibleForTesting
+ void spill() throws IOException {
+ sorter.spill();
+ }
+
+ private void cleanupResources() {
+ sorter.freeMemory();
+ }
+
+ @VisibleForTesting
+ Iterator<InternalRow> sort() throws IOException {
+ try {
+ final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
+ if (!sortedIterator.hasNext()) {
+ // Since we won't ever call next() on an empty iterator, we need to clean up resources
+ // here in order to prevent memory leaks.
+ cleanupResources();
+ }
+ return new AbstractScalaRowIterator() {
+
+ private final int numFields = schema.length();
+ private final UnsafeRow row = new UnsafeRow();
+
+ @Override
+ public boolean hasNext() {
+ return sortedIterator.hasNext();
+ }
+
+ @Override
+ public InternalRow next() {
+ try {
+ sortedIterator.loadNext();
+ row.pointTo(
+ sortedIterator.getBaseObject(),
+ sortedIterator.getBaseOffset(),
+ numFields,
+ sortedIterator.getRecordLength(),
+ null);
+ if (!hasNext()) {
+ // Copy the row before freeing its page so the caller isn't left with a dangling pointer.
+ final InternalRow rowCopy = row.copy();
+ cleanupResources();
+ return rowCopy;
+ }
+ return row;
+ } catch (IOException e) {
+ cleanupResources();
+ // Scala iterators don't declare any checked exceptions, so we need to use this hack
+ // to re-throw the exception:
+ PlatformDependent.throwException(e);
+ }
+ throw new RuntimeException("Exception should have been re-thrown in next()");
+ };
+ };
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+
+ public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
+ while (inputIterator.hasNext()) {
+ insertRow(inputIterator.next());
+ }
+ return sort();
+ }
+
+ /**
+ * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise.
+ */
+ public static boolean supportsSchema(StructType schema) {
+ // TODO: add spilling note to explain why we do this for now:
+ for (StructField field : schema.fields()) {
+ if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static final class RowComparator extends RecordComparator {
+ private final Ordering<InternalRow> ordering;
+ private final int numFields;
+ private final ObjectPool objPool;
+ private final UnsafeRow row1 = new UnsafeRow();
+ private final UnsafeRow row2 = new UnsafeRow();
+
+ public RowComparator(Ordering<InternalRow> ordering, int numFields, ObjectPool objPool) {
+ this.numFields = numFields;
+ this.ordering = ordering;
+ this.objPool = objPool;
+ }
+
+ @Override
+ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+ row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool);
+ row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool);
+ return ordering.compare(row1, row2);
+ }
+ }
+}
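
A rough sketch of how a caller in this package might set one of these up for a single non-nullable ascending LongType column at ordinal 0; `schema`, `ordering` and `unsortedRows` are assumed to exist, and the constructor requires an active SparkEnv and TaskContext, so this only works inside a running task:

    final UnsafeExternalRowSorter.PrefixComputer prefixComputer =
      new UnsafeExternalRowSorter.PrefixComputer() {
        @Override
        long computePrefix(InternalRow row) {
          return row.getLong(0);  // the whole sort key fits in the 8-byte prefix
        }
      };
    final UnsafeExternalRowSorter rowSorter = new UnsafeExternalRowSorter(
      schema, ordering, PrefixComparators.INTEGRAL, prefixComputer);
    final Iterator<InternalRow> sorted = rowSorter.sort(unsortedRows);
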
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
new file mode 100644
index 0000000000..cfefb13e77
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator
+ * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix the
+ * element type to `InternalRow` in order to work around a spurious IntelliJ compiler error.
+ */
+private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 74d9334045..4b783e30d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -289,11 +289,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
}
val withSort = if (needSort) {
- if (sqlContext.conf.externalSortEnabled) {
- ExternalSort(rowOrdering, global = false, withShuffle)
- } else {
- Sort(rowOrdering, global = false, withShuffle)
- }
+ sqlContext.planner.BasicOperators.getSortOperator(
+ rowOrdering, global = false, withShuffle)
} else {
withShuffle
}
@@ -321,11 +318,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
- if (sqlContext.conf.externalSortEnabled) {
- ExternalSort(rowOrdering, global = false, child)
- } else {
- Sort(rowOrdering, global = false, child)
- }
+ sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
new file mode 100644
index 0000000000..2dee3542d6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
+
+
+object SortPrefixUtils {
+
+ /**
+ * A dummy prefix comparator which always claims that prefixes are equal. This is used in cases
+ * where we don't know how to generate or compare prefixes for a SortOrder.
+ */
+ private object NoOpPrefixComparator extends PrefixComparator {
+ override def compare(prefix1: Long, prefix2: Long): Int = 0
+ }
+
+ def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
+ sortOrder.dataType match {
+ case StringType => PrefixComparators.STRING
+ case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL
+ case FloatType => PrefixComparators.FLOAT
+ case DoubleType => PrefixComparators.DOUBLE
+ case _ => NoOpPrefixComparator
+ }
+ }
+
+ def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = {
+ sortOrder.dataType match {
+ case StringType => (row: InternalRow) => {
+ PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String])
+ }
+ case BooleanType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else if (exprVal.asInstanceOf[Boolean]) 1
+ else 0
+ }
+ case ByteType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Byte]
+ }
+ case ShortType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Short]
+ }
+ case IntegerType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Int]
+ }
+ case LongType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Long]
+ }
+ case FloatType => (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX
+ else PrefixComparators.FLOAT.computePrefix(exprVal.asInstanceOf[Float])
+ }
+ case DoubleType => (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX
+ else PrefixComparators.DOUBLE.computePrefix(exprVal.asInstanceOf[Double])
+ }
+ case _ => (row: InternalRow) => 0L
+ }
+ }
+}
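
Written out in Java for a single row, the IntegerType branch above amounts to the following (illustrative only; `sortExpr` stands for `sortOrder.child` already bound to the child plan's output):

    final Object exprVal = sortExpr.eval(row);
    final long prefix = (exprVal == null)
      ? PrefixComparators.INTEGRAL.NULL_PREFIX   // all null keys share one dedicated prefix value
      : ((Integer) exprVal).longValue();         // an int key fits entirely in the 8-byte prefix
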
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 59b9b553a7..ce25af58b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -302,6 +302,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object BasicOperators extends Strategy {
def numPartitions: Int = self.numPartitions
+ /**
+ * Picks an appropriate sort operator.
+ *
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ */
+ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = {
+ if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) {
+ execution.UnsafeExternalSort(sortExprs, global, child)
+ } else if (sqlContext.conf.externalSortEnabled) {
+ execution.ExternalSort(sortExprs, global, child)
+ } else {
+ execution.Sort(sortExprs, global, child)
+ }
+ }
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case r: RunnableCommand => ExecutedCommand(r) :: Nil
@@ -313,11 +329,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.SortPartitions(sortExprs, child) =>
// This sort only sorts tuples within a partition. Its requiredDistribution will be
// an UnspecifiedDistribution.
- execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
- case logical.Sort(sortExprs, global, child) if sqlContext.conf.externalSortEnabled =>
- execution.ExternalSort(sortExprs, global, planLater(child)):: Nil
+ getSortOperator(sortExprs, global = false, planLater(child)) :: Nil
case logical.Sort(sortExprs, global, child) =>
- execution.Sort(sortExprs, global, planLater(child)):: Nil
+ getSortOperator(sortExprs, global, planLater(child)):: Nil
case logical.Project(projectList, child) =>
execution.Project(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index de14e6ad79..4c063c299b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.types.StructType
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.{HashPartitioner, SparkEnv}
@@ -248,6 +250,77 @@ case class ExternalSort(
/**
* :: DeveloperApi ::
+ * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of
+ * Project Tungsten).
+ *
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ * @param testSpillFrequency if positive, forces a spill every `testSpillFrequency` records.
+ * Only intended for configuring periodic spilling in unit tests.
+ */
+@DeveloperApi
+case class UnsafeExternalSort(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan,
+ testSpillFrequency: Int = 0)
+ extends UnaryNode {
+
+ private[this] val schema: StructType = child.schema
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
+ assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
+ def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ val ordering = newOrdering(sortOrder, child.output)
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output)
+ // Hack until we generate separate comparator implementations for ascending vs. descending
+ // (or choose to codegen them):
+ val prefixComparator = {
+ val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+ if (sortOrder.head.direction == Descending) {
+ new PrefixComparator {
+ override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2)
+ }
+ } else {
+ comp
+ }
+ }
+ val prefixComputer = {
+ val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression)
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = prefixComputer(row)
+ }
+ }
+ val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+ if (testSpillFrequency > 0) {
+ sorter.setTestSpillFrequency(testSpillFrequency)
+ }
+ sorter.sort(iterator)
+ }
+ child.execute().mapPartitions(doSort, preservesPartitioning = true)
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+}
+
+@DeveloperApi
+object UnsafeExternalSort {
+ /**
+ * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise.
+ */
+ def supportsSchema(schema: StructType): Boolean = {
+ UnsafeExternalRowSorter.supportsSchema(schema)
+ }
+}
+
+
+/**
+ * :: DeveloperApi ::
* Return a new RDD that has exactly `numPartitions` partitions.
*/
@DeveloperApi
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index a1e3ca11b1..a2c10fdaf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
class SortSuite extends SparkPlanTest {
@@ -33,12 +34,14 @@ class SortSuite extends SparkPlanTest {
checkAnswer(
input.toDF("a", "b", "c"),
- ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
- input.sorted)
+ ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
+ input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
+ sortAnswers = false)
checkAnswer(
input.toDF("a", "b", "c"),
- ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
- input.sortBy(t => (t._2, t._1)))
+ ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
+ input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
+ sortAnswers = false)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 108b1122f7..6a8f394545 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -17,18 +17,15 @@
package org.apache.spark.sql.execution
-import scala.language.implicitConversions
-import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
-
import org.apache.spark.SparkFunSuite
-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util._
-
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame}
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
/**
* Base class for writing tests for individual physical operators. For an example of how this
@@ -49,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
protected def checkAnswer(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer)
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ input :: Nil,
+ (plans: Seq[SparkPlan]) => planFunction(plans.head),
+ expectedAnswer,
+ sortAnswers)
}
/**
@@ -64,86 +68,131 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer(
+ protected def checkAnswer2(
left: DataFrame,
right: DataFrame,
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- checkAnswer(left :: right :: Nil,
- (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer)
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ left :: right :: Nil,
+ (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
+ expectedAnswer,
+ sortAnswers)
}
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
+ * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to
+ * instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer(
+ protected def doCheckAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}
/**
- * Runs the plan and makes sure the answer matches the expected result.
+ * Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+ * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+ * instantiate a reference implementation of the physical operator
+ * that's being tested. The result of executing this plan will be
+ * treated as the source-of-truth for the test.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer[A <: Product : TypeTag](
+ protected def checkThatPlansAgree(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows)
+ expectedPlanFunction: SparkPlan => SparkPlan,
+ sortAnswers: Boolean = true): Unit = {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
}
+}
- /**
- * Runs the plan and makes sure the answer matches the expected result.
- * @param left the left input data to be used.
- * @param right the right input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
- */
- protected def checkAnswer[A <: Product : TypeTag](
- left: DataFrame,
- right: DataFrame,
- planFunction: (SparkPlan, SparkPlan) => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(left, right, planFunction, expectedRows)
- }
+/**
+ * Helper methods for writing tests of individual physical operators.
+ */
+object SparkPlanTest {
/**
- * Runs the plan and makes sure the answer matches the expected result.
+ * Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+ * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+ * instantiate a reference implementation of the physical operator
+ * that's being tested. The result of executing this plan will be
+ * treated as the source-of-truth for the test.
*/
- protected def checkAnswer[A <: Product : TypeTag](
- input: Seq[DataFrame],
- planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows)
- }
+ def checkAnswer(
+ input: DataFrame,
+ planFunction: SparkPlan => SparkPlan,
+ expectedPlanFunction: SparkPlan => SparkPlan,
+ sortAnswers: Boolean): Option[String] = {
-}
+ val outputPlan = planFunction(input.queryExecution.sparkPlan)
+ val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
-/**
- * Helper methods for writing tests of individual physical operators.
- */
-object SparkPlanTest {
+ val expectedAnswer: Seq[Row] = try {
+ executePlan(expectedOutputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan to calculate expected answer:
+ | $expectedOutputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ val actualAnswer: Seq[Row] = try {
+ executePlan(outputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan:
+ | $outputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ s"""
+ | Results do not match.
+ | Actual result Spark plan:
+ | $outputPlan
+ | Expected result Spark plan:
+ | $expectedOutputPlan
+ | $errorMessage
+ """.stripMargin
+ }
+ }
/**
* Runs the plan and makes sure the answer matches the expected result.
@@ -151,28 +200,45 @@ object SparkPlanTest {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
def checkAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[Row]): Option[String] = {
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
- // A very simple resolver to make writing tests easier. In contrast to the real resolver
- // this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = TestSQLContext.prepareForExecution.execute(
- outputPlan transform {
- case plan: SparkPlan =>
- val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
- plan.transformExpressions {
- case UnresolvedAttribute(Seq(u)) =>
- inputMap.getOrElse(u,
- sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
- }
- }
- )
+ val sparkAnswer: Seq[Row] = try {
+ executePlan(outputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan:
+ | $outputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+ compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ s"""
+ | Results do not match for Spark plan:
+ | $outputPlan
+ | $errorMessage
+ """.stripMargin
+ }
+ }
+
+ private def compareAnswers(
+ sparkAnswer: Seq[Row],
+ expectedAnswer: Seq[Row],
+ sort: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -187,40 +253,43 @@ object SparkPlanTest {
case o => o
})
}
- converted.sortBy(_.toString())
- }
-
- val sparkAnswer: Seq[Row] = try {
- resolvedPlan.executeCollect().toSeq
- } catch {
- case NonFatal(e) =>
- val errorMessage =
- s"""
- | Exception thrown while executing Spark plan:
- | $outputPlan
- | == Exception ==
- | $e
- | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin
- return Some(errorMessage)
+ if (sort) {
+ converted.sortBy(_.toString())
+ } else {
+ converted
+ }
}
-
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
val errorMessage =
s"""
- | Results do not match for Spark plan:
- | $outputPlan
| == Results ==
| ${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ s"== Expected Answer - ${expectedAnswer.size} ==" +:
prepareAnswer(expectedAnswer).map(_.toString()),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ s"== Actual Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
""".stripMargin
- return Some(errorMessage)
+ Some(errorMessage)
+ } else {
+ None
}
+ }
- None
+ private def executePlan(outputPlan: SparkPlan): Seq[Row] = {
+ // A very simple resolver to make writing tests easier. In contrast to the real resolver
+ // this is always case sensitive and does not try to handle scoping or complex type resolution.
+ val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+ outputPlan transform {
+ case plan: SparkPlan =>
+ val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
+ plan.transformExpressions {
+ case UnresolvedAttribute(Seq(u)) =>
+ inputMap.getOrElse(u,
+ sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+ }
+ }
+ )
+ resolvedPlan.executeCollect().toSeq
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
new file mode 100644
index 0000000000..4f4c1f2856
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+
+class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+
+ override def beforeAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+ }
+
+ override def afterAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+ }
+
+ ignore("sort followed by limit should not leak memory") {
+ // TODO: this test is going to fail until we implement a proper iterator interface
+ // with a close() method.
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
+ (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ }
+
+ test("sort followed by limit") {
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
+ try {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
+ (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ } finally {
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+
+ }
+ }
+
+ test("sorting does not crash for large inputs") {
+ val sortOrder = 'a.asc :: Nil
+ val stringLength = 1024 * 1024 * 2
+ checkThatPlansAgree(
+ Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
+ UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
+ Sort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+
+  // Test sorting on different data types
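+  // One test is generated per combination of atomic type, nullability and sort direction.
+  // Decimals are skipped because UnsafeRow has no decimal representation yet, and NaN values are
+  // filtered out, presumably because the two sort operators may order NaNs differently.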
+  for (
+    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
+    if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+    nullable <- Seq(true, false);
+    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
+    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
+  ) {
+    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
+      val inputData = Seq.fill(1000)(randomDataGenerator()).filter {
+        case d: Double => !d.isNaN
+        case f: Float => !java.lang.Float.isNaN(f)
+        case _ => true
+      }
+      val inputDf = TestSQLContext.createDataFrame(
+        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+        StructType(StructField("a", dataType, nullable = true) :: Nil)
+      )
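+      // Verify up front that UnsafeExternalSort.supportsSchema covers this type; the comparison
+      // below then checks that its output matches the existing Sort operator.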
+      assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
+      checkThatPlansAgree(
+        inputDf,
+        UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23),
+        Sort(sortOrder, global = true, _: SparkPlan),
+        sortAnswers = false
+      )
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 5707d2fb30..2c27da596b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
@@ -41,23 +42,23 @@ class OuterJoinSuite extends SparkPlanTest {
  val condition = Some(LessThan('b, 'd))

  test("shuffled hash outer join") {
-    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
      ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
      Seq(
        (1, 2.0, null, null),
        (2, 1.0, 2, 3.0),
        (3, 3.0, null, null)
-      ))
+      ).map(Row.fromTuple))

-    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
      ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
      Seq(
        (2, 1.0, 2, 3.0),
        (null, null, 3, 2.0),
        (null, null, 4, 1.0)
-      ))
+      ).map(Row.fromTuple))

-    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
      ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
      Seq(
        (1, 2.0, null, null),
@@ -65,24 +66,24 @@ class OuterJoinSuite extends SparkPlanTest {
        (3, 3.0, null, null),
        (null, null, 3, 2.0),
        (null, null, 4, 1.0)
-      ))
+      ).map(Row.fromTuple))
  }

  test("broadcast hash outer join") {
-    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
      BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
      Seq(
        (1, 2.0, null, null),
        (2, 1.0, 2, 3.0),
        (3, 3.0, null, null)
-      ))
+      ).map(Row.fromTuple))

-    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
      BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
      Seq(
        (2, 1.0, 2, 3.0),
        (null, null, 3, 2.0),
        (null, null, 4, 1.0)
-      ))
+      ).map(Row.fromTuple))
  }
}