Diffstat (limited to 'core')
-rw-r--r--  core/pom.xml  20
-rw-r--r--  core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java)  9
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java  1
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java  29
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java  109
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java  37
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java  31
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java  282
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java  189
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java  80
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java  35
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java  91
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java  98
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java  146
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java  202
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java  139
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala  50
17 files changed, 1533 insertions, 15 deletions
diff --git a/core/pom.xml b/core/pom.xml
index aee0d92620..558cc3fb9f 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -343,28 +343,28 @@
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.mockito</groupId>
- <artifactId>mockito-core</artifactId>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-library</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>junit</groupId>
- <artifactId>junit</artifactId>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.hamcrest</groupId>
- <artifactId>hamcrest-core</artifactId>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.hamcrest</groupId>
- <artifactId>hamcrest-library</artifactId>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
index 3f746b886b..0399abc63c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
+++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.serializer;
import java.io.IOException;
import java.io.InputStream;
@@ -24,9 +24,7 @@ import java.nio.ByteBuffer;
import scala.reflect.ClassTag;
-import org.apache.spark.serializer.DeserializationStream;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.PlatformDependent;
/**
@@ -35,7 +33,8 @@ import org.apache.spark.unsafe.PlatformDependent;
* `write()` OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
* around this, we pass a dummy no-op serializer.
*/
-final class DummySerializerInstance extends SerializerInstance {
+@Private
+public final class DummySerializerInstance extends SerializerInstance {
public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 9e9ed94b78..5628957320 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -30,6 +30,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
new file mode 100644
index 0000000000..45b78829e4
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific
+ * comparisons, such as lexicographic comparison for strings.
+ */
+@Private
+public abstract class PrefixComparator {
+ public abstract int compare(long prefix1, long prefix2);
+}
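
A minimal sketch of a concrete subclass (same package assumed), using Guava's UnsignedLongs, which is already on core's classpath, to order prefixes as unsigned 64-bit values. Unsigned comparison matches byte-wise comparison of big-endian packed prefixes, such as the string prefixes defined in the next file:

import com.google.common.primitives.UnsignedLongs;

final class UnsignedPrefixComparator extends PrefixComparator {
  @Override
  public int compare(long prefix1, long prefix2) {
    // Unsigned 64-bit comparison == lexicographic order of the big-endian bytes.
    return UnsignedLongs.compare(prefix1, prefix2);
  }
}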
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
new file mode 100644
index 0000000000..438742565c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.google.common.base.Charsets;
+import com.google.common.primitives.Longs;
+import com.google.common.primitives.UnsignedBytes;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.unsafe.types.UTF8String;
+
+@Private
+public class PrefixComparators {
+ private PrefixComparators() {}
+
+ public static final StringPrefixComparator STRING = new StringPrefixComparator();
+ public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator();
+ public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
+ public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
+
+ public static final class StringPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ // TODO: this can be done more efficiently
+ byte[] a = Longs.toByteArray(aPrefix);
+ byte[] b = Longs.toByteArray(bPrefix);
+ for (int i = 0; i < 8; i++) {
+ int c = UnsignedBytes.compare(a[i], b[i]);
+ if (c != 0) return c;
+ }
+ return 0;
+ }
+
+ public long computePrefix(byte[] bytes) {
+ if (bytes == null) {
+ return 0L;
+ } else {
+ byte[] padded = new byte[8];
+ System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8));
+ return Longs.fromByteArray(padded);
+ }
+ }
+
+ public long computePrefix(String value) {
+ return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8));
+ }
+
+ public long computePrefix(UTF8String value) {
+ return value == null ? 0L : computePrefix(value.getBytes());
+ }
+ }
+
+ /**
+ * Prefix comparator for all integral types (boolean, byte, short, int, long).
+ */
+ public static final class IntegralPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long a, long b) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public final long NULL_PREFIX = Long.MIN_VALUE;
+ }
+
+ public static final class FloatPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ float a = Float.intBitsToFloat((int) aPrefix);
+ float b = Float.intBitsToFloat((int) bPrefix);
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public long computePrefix(float value) {
+ return Float.floatToIntBits(value) & 0xffffffffL;
+ }
+
+ public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY);
+ }
+
+ public static final class DoublePrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ double a = Double.longBitsToDouble(aPrefix);
+ double b = Double.longBitsToDouble(bPrefix);
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public long computePrefix(double value) {
+ return Double.doubleToLongBits(value);
+ }
+
+ public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY);
+ }
+}
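
A usage sketch for the string comparator (same package assumed): computePrefix packs the first eight UTF-8 bytes big-endian, so prefix order agrees with full string order whenever two strings already differ within those bytes:

public class StringPrefixDemo {
  public static void main(String[] args) {
    final long apple = PrefixComparators.STRING.computePrefix("apple");
    final long banana = PrefixComparators.STRING.computePrefix("banana");
    // Prints a negative number: "apple" orders before "banana" on the first byte alone.
    System.out.println(PrefixComparators.STRING.compare(apple, banana));
  }
}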
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
new file mode 100644
index 0000000000..09e4258792
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+/**
+ * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte
+ * prefix, this may simply return 0.
+ */
+public abstract class RecordComparator {
+
+ /**
+ * Compare two records for order.
+ *
+ * @return a negative integer, zero, or a positive integer as the first record is less than,
+ * equal to, or greater than the second.
+ */
+ public abstract int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset);
+}
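
A minimal sketch of a concrete comparator, for records whose payload is a single 4-byte int (the shape the test suite below inserts). The offsets passed to compare() already point past the 4-byte length word, and PlatformDependent.UNSAFE is the same access handle the rest of this patch uses:

import org.apache.spark.unsafe.PlatformDependent;

final class IntRecordComparator extends RecordComparator {
  @Override
  public int compare(
      Object leftBaseObject,
      long leftBaseOffset,
      Object rightBaseObject,
      long rightBaseOffset) {
    final int left = PlatformDependent.UNSAFE.getInt(leftBaseObject, leftBaseOffset);
    final int right = PlatformDependent.UNSAFE.getInt(rightBaseObject, rightBaseOffset);
    return (left < right) ? -1 : (left > right) ? 1 : 0;
  }
}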
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
new file mode 100644
index 0000000000..0c4ebde407
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+final class RecordPointerAndKeyPrefix {
+ /**
+ * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
+ * description of how these addresses are encoded.
+ */
+ public long recordPointer;
+
+ /**
+ * A key prefix, for use in comparisons.
+ */
+ public long keyPrefix;
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
new file mode 100644
index 0000000000..4d6731ee60
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * External sorter based on {@link UnsafeInMemorySorter}.
+ */
+public final class UnsafeExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
+
+ private static final int PAGE_SIZE = 1 << 27; // 128 megabytes
+ @VisibleForTesting
+ static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+ private final PrefixComparator prefixComparator;
+ private final RecordComparator recordComparator;
+ private final int initialSize;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private ShuffleWriteMetrics writeMetrics;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ // These variables are reset after spilling:
+ private UnsafeInMemorySorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+ private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
+ public UnsafeExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ SparkConf conf) throws IOException {
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.initialSize = initialSize;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ initializeForWriting();
+ }
+
+ // TODO: metrics tracking + integration with shuffle write metrics
+ // need to connect the write metrics to task metrics so we count the spill IO somewhere.
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ this.writeMetrics = new ShuffleWriteMetrics();
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.sorter =
+ new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize);
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ public void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spillWriters.size(),
+ spillWriters.size() > 1 ? "times" : "time");
+
+ final UnsafeSorterSpillWriter spillWriter =
+ new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+ sorter.numRecords());
+ spillWriters.add(spillWriter);
+ final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final Object baseObject = sortedRecords.getBaseObject();
+ final long baseOffset = sortedRecords.getBaseOffset();
+ final int recordLength = sortedRecords.getRecordLength();
+ spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+ }
+ spillWriter.close();
+ final long sorterMemoryUsage = sorter.getMemoryUsage();
+ sorter = null;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+ }
+
+ public long freeMemory() {
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ memoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ *
+ * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+ assert (requiredSpace > 0);
+ return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ */
+ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ // TODO: merge these steps to first calculate total memory requirements for this insert,
+ // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
+ // data page.
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ sorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+
+ if (requiredSpace > freeSpaceInCurrentPage) {
+ logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
+ freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > PAGE_SIZE) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ PAGE_SIZE + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquired < PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+ }
+ }
+ currentPage = memoryManager.allocatePage(PAGE_SIZE);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = PAGE_SIZE;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ long prefix) throws IOException {
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+ currentPagePosition += 4;
+ PlatformDependent.copyMemory(
+ recordBaseObject,
+ recordBaseOffset,
+ dataPageBaseObject,
+ currentPagePosition,
+ lengthInBytes);
+ currentPagePosition += lengthInBytes;
+
+ sorter.insertRecord(recordAddress, prefix);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+ int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+ if (spillWriters.isEmpty()) {
+ return inMemoryIterator;
+ } else {
+ final UnsafeSorterSpillMerger spillMerger =
+ new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
+ for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+ spillMerger.addSpill(spillWriter.getReader(blockManager));
+ }
+ spillWriters.clear();
+ if (inMemoryIterator.hasNext()) {
+ spillMerger.addSpill(inMemoryIterator);
+ }
+ return spillMerger.getSortedIterator();
+ }
+ }
+}
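
For orientation, a condensed usage sketch, assuming the memory managers, block manager, and task context are already in hand (the fully mocked construction lives in UnsafeExternalSorterSuite later in this diff). Records are raw bytes, so an int travels as a 4-byte payload with its value doubling as the sort prefix:

final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
    memoryManager, shuffleMemoryManager, blockManager, taskContext,
    recordComparator, prefixComparator, 1024, new SparkConf());

final int[] record = new int[] { 42 };
sorter.insertRecord(record, PlatformDependent.INT_ARRAY_OFFSET, 4, 42L);

final UnsafeSorterIterator iter = sorter.getSortedIterator();
while (iter.hasNext()) {
  iter.loadNext();
  final int value =
      PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset());
}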
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
new file mode 100644
index 0000000000..fc34ad9cff
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Comparator;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
+ * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm
+ * compares records, it will first compare the stored key prefixes; if the prefixes are not equal,
+ * then we do not need to traverse the record pointers to compare the actual records. Avoiding these
+ * random memory accesses improves cache hit rates.
+ */
+public final class UnsafeInMemorySorter {
+
+ private static final class SortComparator implements Comparator<RecordPointerAndKeyPrefix> {
+
+ private final RecordComparator recordComparator;
+ private final PrefixComparator prefixComparator;
+ private final TaskMemoryManager memoryManager;
+
+ SortComparator(
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ TaskMemoryManager memoryManager) {
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.memoryManager = memoryManager;
+ }
+
+ @Override
+ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
+ final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix);
+ if (prefixComparisonResult == 0) {
+ final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
+ final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length
+ final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
+ final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length
+ return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ }
+
+ private final TaskMemoryManager memoryManager;
+ private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+ private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
+
+ /**
+ * Within this buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the sort buffer where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeInMemorySorter(
+ final TaskMemoryManager memoryManager,
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize * 2];
+ this.memoryManager = memoryManager;
+ this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
+ this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ }
+
+ /**
+ * @return the number of records that have been inserted into this sorter.
+ */
+ public int numRecords() {
+ return pointerArrayInsertPosition / 2;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 2 < pointerArray.length;
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ /**
+ * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+ * stored as a 4-byte integer, followed by the record's bytes.
+ *
+ * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
+ * @param keyPrefix a user-defined key prefix
+ */
+ public void insertRecord(long recordPointer, long keyPrefix) {
+ if (!hasSpaceForAnotherRecord()) {
+ expandPointerArray();
+ }
+ pointerArray[pointerArrayInsertPosition] = recordPointer;
+ pointerArrayInsertPosition++;
+ pointerArray[pointerArrayInsertPosition] = keyPrefix;
+ pointerArrayInsertPosition++;
+ }
+
+ private static final class SortedIterator extends UnsafeSorterIterator {
+
+ private final TaskMemoryManager memoryManager;
+ private final int sortBufferInsertPosition;
+ private final long[] sortBuffer;
+ private int position = 0;
+ private Object baseObject;
+ private long baseOffset;
+ private long keyPrefix;
+ private int recordLength;
+
+ SortedIterator(
+ TaskMemoryManager memoryManager,
+ int sortBufferInsertPosition,
+ long[] sortBuffer) {
+ this.memoryManager = memoryManager;
+ this.sortBufferInsertPosition = sortBufferInsertPosition;
+ this.sortBuffer = sortBuffer;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position < sortBufferInsertPosition;
+ }
+
+ @Override
+ public void loadNext() {
+ // This pointer points to a 4-byte record length, followed by the record's bytes
+ final long recordPointer = sortBuffer[position];
+ baseObject = memoryManager.getPage(recordPointer);
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
+ recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
+ keyPrefix = sortBuffer[position + 1];
+ position += 2;
+ }
+
+ @Override
+ public Object getBaseObject() { return baseObject; }
+
+ @Override
+ public long getBaseOffset() { return baseOffset; }
+
+ @Override
+ public int getRecordLength() { return recordLength; }
+
+ @Override
+ public long getKeyPrefix() { return keyPrefix; }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order. For efficiency, all calls to
+ * {@code next()} will return the same mutable object.
+ */
+ public UnsafeSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
+ return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+ }
+}
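
A Spark-free illustration of the key-prefix idea described in the class comment above: sort (pointer, prefix) pairs by prefix first and dereference the pointer only on ties. Here an array index stands in for the record pointer and the first character for the 8-byte prefix:

import java.util.Arrays;
import java.util.Comparator;

public class KeyPrefixSortDemo {
  public static void main(String[] args) {
    final String[] records = { "banana", "apple", "applet" };
    final long[][] entries = new long[records.length][2];
    for (int i = 0; i < records.length; i++) {
      entries[i][0] = i;                     // "record pointer" (array index)
      entries[i][1] = records[i].charAt(0);  // toy one-character key prefix
    }
    Arrays.sort(entries, new Comparator<long[]>() {
      @Override
      public int compare(long[] a, long[] b) {
        final int byPrefix = (a[1] < b[1]) ? -1 : (a[1] > b[1]) ? 1 : 0;
        if (byPrefix != 0) return byPrefix;  // no record access needed
        // Prefix tie: fall back to comparing the records themselves.
        return records[(int) a[0]].compareTo(records[(int) b[0]]);
      }
    });
    // Prints apple, applet, banana; only apple/applet needed a full comparison.
    for (long[] e : entries) System.out.println(records[(int) e[0]]);
  }
}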
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
new file mode 100644
index 0000000000..d09c728a7a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+/**
+ * Supports sorting an array of (record pointer, key prefix) pairs.
+ * Used in {@link UnsafeInMemorySorter}.
+ * <p>
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+
+ public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+
+ private UnsafeSortDataFormat() { }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix newKey() {
+ return new RecordPointerAndKeyPrefix();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
+ reuse.recordPointer = data[pos * 2];
+ reuse.keyPrefix = data[pos * 2 + 1];
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ long tempPointer = data[pos0 * 2];
+ long tempKeyPrefix = data[pos0 * 2 + 1];
+ data[pos0 * 2] = data[pos1 * 2];
+ data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
+ data[pos1 * 2] = tempPointer;
+ data[pos1 * 2 + 1] = tempKeyPrefix;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos * 2] = src[srcPos * 2];
+ dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
+ return new long[length * 2];
+ }
+
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
new file mode 100644
index 0000000000..16ac2e8d82
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+
+public abstract class UnsafeSorterIterator {
+
+ public abstract boolean hasNext();
+
+ public abstract void loadNext() throws IOException;
+
+ public abstract Object getBaseObject();
+
+ public abstract long getBaseOffset();
+
+ public abstract int getRecordLength();
+
+ public abstract long getKeyPrefix();
+}
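
The contract is cursor-style: loadNext() advances shared mutable state, after which the getters expose the current record. A typical consumer (process() is a hypothetical callback, not part of this patch) looks like:

while (iter.hasNext()) {
  iter.loadNext();
  // process() stands in for whatever consumes the record bytes.
  process(iter.getBaseObject(), iter.getBaseOffset(),
          iter.getRecordLength(), iter.getKeyPrefix());
}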
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
new file mode 100644
index 0000000000..8272c2a5be
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.PriorityQueue;
+
+final class UnsafeSorterSpillMerger {
+
+ private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
+
+ public UnsafeSorterSpillMerger(
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ final int numSpills) {
+ final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
+
+ @Override
+ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
+ final int prefixComparisonResult =
+ prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
+ if (prefixComparisonResult == 0) {
+ return recordComparator.compare(
+ left.getBaseObject(), left.getBaseOffset(),
+ right.getBaseObject(), right.getBaseOffset());
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ };
+ priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
+ }
+
+ public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ }
+ priorityQueue.add(spillReader);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ return new UnsafeSorterIterator() {
+
+ private UnsafeSorterIterator spillReader;
+
+ @Override
+ public boolean hasNext() {
+ return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ if (spillReader != null) {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ priorityQueue.add(spillReader);
+ }
+ }
+ spillReader = priorityQueue.remove();
+ }
+
+ @Override
+ public Object getBaseObject() { return spillReader.getBaseObject(); }
+
+ @Override
+ public long getBaseOffset() { return spillReader.getBaseOffset(); }
+
+ @Override
+ public int getRecordLength() { return spillReader.getRecordLength(); }
+
+ @Override
+ public long getKeyPrefix() { return spillReader.getKeyPrefix(); }
+ };
+ }
+}
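
The design is a textbook k-way merge: one cursor per sorted input sits in a priority queue keyed by its current element; repeatedly popping, emitting, advancing, and re-inserting yields a globally sorted stream. A self-contained, Spark-free illustration:

import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

public class KWayMergeDemo {
  static final class Cursor {
    final Iterator<Integer> it;
    int current;
    Cursor(Iterator<Integer> it) { this.it = it; this.current = it.next(); }
    boolean advance() {
      if (!it.hasNext()) return false;
      current = it.next();
      return true;
    }
  }

  public static void main(String[] args) {
    final List<List<Integer>> runs = Arrays.asList(
        Arrays.asList(1, 4, 7), Arrays.asList(2, 5), Arrays.asList(3, 6));
    final PriorityQueue<Cursor> queue = new PriorityQueue<Cursor>(runs.size(),
        new Comparator<Cursor>() {
          @Override
          public int compare(Cursor a, Cursor b) {
            return (a.current < b.current) ? -1 : (a.current > b.current) ? 1 : 0;
          }
        });
    for (List<Integer> run : runs) {
      queue.add(new Cursor(run.iterator()));  // assumes non-empty runs
    }
    while (!queue.isEmpty()) {
      final Cursor c = queue.remove();
      System.out.print(c.current + " ");      // prints: 1 2 3 4 5 6 7
      if (c.advance()) {
        queue.add(c);                         // re-insert with its next element
      }
    }
  }
}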
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
new file mode 100644
index 0000000000..29e9e0f30f
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.*;
+
+import com.google.common.io.ByteStreams;
+
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
+ * of the file format).
+ */
+final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+
+ private InputStream in;
+ private DataInputStream din;
+
+ // Variables that change with every record read:
+ private int recordLength;
+ private long keyPrefix;
+ private int numRecordsRemaining;
+
+ private byte[] arr = new byte[1024 * 1024];
+ private Object baseObject = arr;
+ private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+
+ public UnsafeSorterSpillReader(
+ BlockManager blockManager,
+ File file,
+ BlockId blockId) throws IOException {
+ assert (file.length() > 0);
+ final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+ this.in = blockManager.wrapForCompression(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecordsRemaining = din.readInt();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ recordLength = din.readInt();
+ keyPrefix = din.readLong();
+ if (recordLength > arr.length) {
+ arr = new byte[recordLength];
+ baseObject = arr;
+ }
+ ByteStreams.readFully(in, arr, 0, recordLength);
+ numRecordsRemaining--;
+ if (numRecordsRemaining == 0) {
+ in.close();
+ in = null;
+ din = null;
+ }
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return baseOffset;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
new file mode 100644
index 0000000000..b8d6665980
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Tuple2;
+
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Spills a list of sorted records to disk. Spill files have the following format:
+ *
+ * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
+ */
+final class UnsafeSorterSpillWriter {
+
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array.
+ private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ private final File file;
+ private final BlockId blockId;
+ private final int numRecordsToWrite;
+ private BlockObjectWriter writer;
+ private int numRecordsSpilled = 0;
+
+ public UnsafeSorterSpillWriter(
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ int numRecordsToWrite) throws IOException {
+ final Tuple2<TempLocalBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempLocalBlock();
+ this.file = spilledFileInfo._2();
+ this.blockId = spilledFileInfo._1();
+ this.numRecordsToWrite = numRecordsToWrite;
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ writer = blockManager.getDiskWriter(
+ blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+ // Write the number of records
+ writeIntToBuffer(numRecordsToWrite, 0);
+ writer.write(writeBuffer, 0, 4);
+ }
+
+ // Based on DataOutputStream.writeLong.
+ private void writeLongToBuffer(long v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 56);
+ writeBuffer[offset + 1] = (byte)(v >>> 48);
+ writeBuffer[offset + 2] = (byte)(v >>> 40);
+ writeBuffer[offset + 3] = (byte)(v >>> 32);
+ writeBuffer[offset + 4] = (byte)(v >>> 24);
+ writeBuffer[offset + 5] = (byte)(v >>> 16);
+ writeBuffer[offset + 6] = (byte)(v >>> 8);
+ writeBuffer[offset + 7] = (byte)(v >>> 0);
+ }
+
+ // Based on DataOutputStream.writeInt.
+ private void writeIntToBuffer(int v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 24);
+ writeBuffer[offset + 1] = (byte)(v >>> 16);
+ writeBuffer[offset + 2] = (byte)(v >>> 8);
+ writeBuffer[offset + 3] = (byte)(v >>> 0);
+ }
+
+ /**
+ * Write a record to a spill file.
+ *
+ * @param baseObject the base object / memory page containing the record
+ * @param baseOffset the base offset which points directly to the record data.
+ * @param recordLength the length of the record.
+ * @param keyPrefix a sort key prefix
+ */
+ public void write(
+ Object baseObject,
+ long baseOffset,
+ int recordLength,
+ long keyPrefix) throws IOException {
+ if (numRecordsSpilled == numRecordsToWrite) {
+ throw new IllegalStateException(
+ "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+ } else {
+ numRecordsSpilled++;
+ }
+ writeIntToBuffer(recordLength, 0);
+ writeLongToBuffer(keyPrefix, 4);
+ int dataRemaining = recordLength;
+ int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len
+ long recordReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+ toTransfer);
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE;
+ }
+ if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) {
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer));
+ }
+ writer.recordWritten();
+ }
+
+ public void close() throws IOException {
+ writer.commitAndClose();
+ writer = null;
+ writeBuffer = null;
+ }
+
+ public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
+ return new UnsafeSorterSpillReader(blockManager, file, blockId);
+ }
+}
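
A hedged sketch of decoding this file format offline, assuming an uncompressed stream (the real reader first wraps the stream via blockManager.wrapForCompression; see UnsafeSorterSpillReader above):

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

public class SpillFileDump {
  public static void main(String[] args) throws IOException {
    final DataInputStream in = new DataInputStream(new FileInputStream(args[0]));
    try {
      final int numRecords = in.readInt();      // [# of records (int)]
      for (int i = 0; i < numRecords; i++) {
        final int length = in.readInt();        // [len (int)]
        final long prefix = in.readLong();      // [prefix (long)]
        final byte[] data = new byte[length];   // [data (bytes)]
        in.readFully(data);
        System.out.println("record " + i + ": len=" + length + " prefix=" + prefix);
      }
    } finally {
      in.close();
    }
  }
}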
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
new file mode 100644
index 0000000000..ea8755e21e
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.UUID;
+
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeExternalSorterSuite {
+
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+
+ File tempDir;
+
+ private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ return stream;
+ }
+ }
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ tempDir = new File(Utils.createTempDir$default$1());
+ taskContext = mock(TaskContext.class);
+ when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+ @Override
+ public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+ TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (BlockId) args[0],
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+ .then(returnsSecondArg());
+ }
+
+ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
+ final int[] arr = new int[] { value };
+ sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
+ }
+
+ @Test
+ public void testSortingOnlyByPrefix() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ insertNumber(sorter, 5);
+ insertNumber(sorter, 1);
+ insertNumber(sorter, 3);
+ sorter.spill();
+ insertNumber(sorter, 4);
+ sorter.spill();
+ insertNumber(sorter, 2);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(i, iter.getKeyPrefix());
+ assertEquals(4, iter.getRecordLength());
+ // TODO: read rest of value.
+ }
+
+ // TODO: test for cleanup:
+ // assert(tempDir.isEmpty)
+ }
+
+ @Test
+ public void testSortingEmptyArrays() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(0, iter.getKeyPrefix());
+ assertEquals(0, iter.getRecordLength());
+ }
+ }
+
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
new file mode 100644
index 0000000000..9095009305
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+
+import org.junit.Test;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.*;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
+ final byte[] strBytes = new byte[length];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, length);
+    return new String(strBytes, StandardCharsets.UTF_8);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+ mock(RecordComparator.class),
+ mock(PrefixComparator.class),
+ 100);
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+    assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testSortingOnlyByIntegerPrefix() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ // Write the records into the data page:
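+    // Each record is length-prefixed: [int32 length][UTF-8 bytes]; `position` advances
+    // past each record as it is written.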
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+      final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8);
+ PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ PlatformDependent.copyMemory(
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ position,
+ strBytes.length);
+ position += strBytes.length;
+ }
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+ // Compute key prefixes based on the records' partition ids
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+        // Safe here: prefixes are partition ids in [0, 4), so the subtraction cannot overflow.
+        return (int) prefix1 - (int) prefix2;
+ }
+ };
+ UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
+ prefixComparator, dataToSort.length);
+ // Given a page of records, insert those records into the sorter one-by-one:
+ position = dataPage.getBaseOffset();
+ for (int i = 0; i < dataToSort.length; i++) {
+ // position now points to the start of a record (which holds its length).
+ final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position);
+ final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
+ final int partitionId = hashPartitioner.getPartition(str);
+ sorter.insertRecord(address, partitionId);
+ position += 4 + recordLength;
+ }
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
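+    // The sort key is only the partition id (the prefix), so order within a partition is
+    // unspecified; verify membership and non-decreasing prefixes rather than an exact sequence.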
+ int iterLength = 0;
+ long prevPrefix = -1;
+ Arrays.sort(dataToSort);
+ while (iter.hasNext()) {
+ iter.loadNext();
+ final String str =
+ getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength());
+ final long keyPrefix = iter.getKeyPrefix();
+ assertThat(str, isIn(Arrays.asList(dataToSort)));
+ assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
+ prevPrefix = keyPrefix;
+ iterLength++;
+ }
+ assertEquals(dataToSort.length, iterLength);
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
new file mode 100644
index 0000000000..dd505dfa7d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort
+
+import org.scalatest.prop.PropertyChecks
+
+import org.apache.spark.SparkFunSuite
+
+class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
+
+ test("String prefix comparator") {
+
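+    // The prefix comparator may report a tie (0) for unequal strings, but any strict
+    // ordering it reports must agree with full string comparison.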
+ def testPrefixComparison(s1: String, s2: String): Unit = {
+ val s1Prefix = PrefixComparators.STRING.computePrefix(s1)
+ val s2Prefix = PrefixComparators.STRING.computePrefix(s2)
+ val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
+ assert(
+ (prefixComparisonResult == 0) ||
+ (prefixComparisonResult < 0 && s1 < s2) ||
+ (prefixComparisonResult > 0 && s1 > s2))
+ }
+
+ // scalastyle:off
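+    // (disabled because the regression-test table below contains non-ASCII strings)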
+ val regressionTests = Table(
+ ("s1", "s2"),
+ ("abc", "世界"),
+ ("你好", "世界"),
+ ("你好123", "你好122")
+ )
+ // scalastyle:on
+
+ forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ }
+}