-rw-r--r--  core/pom.xml | 10
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java | 93
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java | 92
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java | 37
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 422
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java | 124
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java | 67
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 438
-rw-r--r--  core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java | 75
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala | 205
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 2
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java | 101
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java | 132
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java | 527
-rw-r--r--  core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala | 44
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala | 29
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala | 128
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala | 105
-rw-r--r--  pom.xml | 14
-rw-r--r--  project/MimaExcludes.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 28
-rw-r--r--  unsafe/pom.xml | 4
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java | 79
-rw-r--r--  unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java | 23
33 files changed, 2767 insertions(+), 64 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 262a3320db..bfa49d0d6d 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -362,6 +362,16 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-library</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
<scope>test</scope>
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
new file mode 100644
index 0000000000..3f746b886b
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+
+import scala.reflect.ClassTag;
+
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ * Our shuffle write path doesn't actually use this serializer (since we end up calling the
+ * `write()` OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ * around this, we pass a dummy no-op serializer.
+ */
+final class DummySerializerInstance extends SerializerInstance {
+
+ public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
+
+ private DummySerializerInstance() { }
+
+ @Override
+ public SerializationStream serializeStream(final OutputStream s) {
+ return new SerializationStream() {
+ @Override
+ public void flush() {
+ // Need to implement this because DiskBlockObjectWriter uses it to flush the compression stream
+ try {
+ s.flush();
+ } catch (IOException e) {
+ PlatformDependent.throwException(e);
+ }
+ }
+
+ @Override
+ public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void close() {
+ // Need to implement this because DiskBlockObjectWriter uses it to close the compression stream
+ try {
+ s.close();
+ } catch (IOException e) {
+ PlatformDependent.throwException(e);
+ }
+ }
+ };
+ }
+
+ @Override
+ public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DeserializationStream deserializeStream(InputStream s) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
new file mode 100644
index 0000000000..4ee6a82c04
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+/**
+ * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ * <p>
+ * Within the long, the data is laid out as follows:
+ * <pre>
+ * [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * </pre>
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ * <p>
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
+ */
+final class PackedRecordPointer {
+
+ static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes
+
+ /**
+ * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
+ */
+ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+
+ /** Bit mask for the lower 40 bits of a long. */
+ private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
+
+ /** Bit mask for the upper 24 bits of a long. */
+ private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
+
+ /** Bit mask for the lower 27 bits of a long. */
+ private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
+
+ /** Bit mask for the upper 13 bits of a long. */
+ private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+ /**
+ * Pack a record address and partition id into a single word.
+ *
+ * @param recordPointer a record pointer encoded by TaskMemoryManager.
+ * @param partitionId a shuffle partition id (maximum value of 2^24 - 1).
+ * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
+ */
+ public static long packPointer(long recordPointer, int partitionId) {
+ assert (partitionId <= MAXIMUM_PARTITION_ID);
+ // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
+ // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
+ final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
+ final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
+ return (((long) partitionId) << 40) | compressedAddress;
+ }
+
+ private long packedRecordPointer;
+
+ public void set(long packedRecordPointer) {
+ this.packedRecordPointer = packedRecordPointer;
+ }
+
+ public int getPartitionId() {
+ return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
+ }
+
+ public long getRecordPointer() {
+ final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
+ final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
+ return pageNumber | offsetInPage;
+ }
+
+}
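
As an aside for readers of the encoding above, the following standalone sketch (not part of the patch; the class and variable names are made up) packs and unpacks a pointer with the same bit arithmetic as packPointer(), getPartitionId() and getRecordPointer(), assuming TaskMemoryManager's 13-bit-page / 51-bit-offset raw address layout described in the class comment:

    // Illustrative only: mirrors the [24-bit partition][13-bit page][27-bit offset] layout.
    public class PackedPointerSketch {
      static final long LOWER_27 = (1L << 27) - 1;    // offset-in-page bits
      static final long LOWER_40 = (1L << 40) - 1;    // compressed address bits
      static final long UPPER_13 = ~((1L << 51) - 1); // page-number bits of a raw task-memory address

      static long pack(long recordPointer, int partitionId) {
        final long pageNumber = (recordPointer & UPPER_13) >>> 24;
        return (((long) partitionId) << 40) | pageNumber | (recordPointer & LOWER_27);
      }

      public static void main(String[] args) {
        // A hypothetical raw pointer: page number 5, offset 1234 within the page.
        final long rawPointer = (5L << 51) | 1234L;
        final long packed = pack(rawPointer, 42);
        final int partitionId = (int) ((packed & ~LOWER_40) >>> 40);
        final long decoded = ((packed << 24) & UPPER_13) | (packed & LOWER_27);
        System.out.println(partitionId);                  // 42
        System.out.println((decoded & UPPER_13) >>> 51);  // page 5
        System.out.println(decoded & LOWER_27);           // offset 1234
      }
    }
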
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
new file mode 100644
index 0000000000..7bac0dc0bb
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+
+import org.apache.spark.storage.TempShuffleBlockId;
+
+/**
+ * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
+ */
+final class SpillInfo {
+ final long[] partitionLengths;
+ final File file;
+ final TempShuffleBlockId blockId;
+
+ public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+ this.partitionLengths = new long[numPartitions];
+ this.file = file;
+ this.blockId = blockId;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
new file mode 100644
index 0000000000..9e9ed94b78
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * An external sorter that is specialized for sort-based shuffle.
+ * <p>
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * written to a single output file (or multiple files, if we've spilled). The format of the output
+ * files is the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ * <p>
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
+ */
+final class UnsafeShuffleExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+
+ private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+ @VisibleForTesting
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+ @VisibleForTesting
+ static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+ private final int initialSize;
+ private final int numPartitions;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private final ShuffleWriteMetrics writeMetrics;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+
+ // These variables are reset after spilling:
+ private UnsafeShuffleInMemorySorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+ public UnsafeShuffleExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ int initialSize,
+ int numPartitions,
+ SparkConf conf,
+ ShuffleWriteMetrics writeMetrics) throws IOException {
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.initialSize = initialSize;
+ this.numPartitions = numPartitions;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+
+ this.writeMetrics = writeMetrics;
+ initializeForWriting();
+ }
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.sorter = new UnsafeShuffleInMemorySorter(initialSize);
+ }
+
+ /**
+ * Sorts the in-memory records and writes the sorted records to an on-disk file.
+ * This method does not free the sort data structures.
+ *
+ * @param isLastFile if true, this indicates that we're writing the final output file and that the
+ * bytes written should be counted towards shuffle write metrics rather than
+ * shuffle spill metrics.
+ */
+ private void writeSortedFile(boolean isLastFile) throws IOException {
+
+ final ShuffleWriteMetrics writeMetricsToUse;
+
+ if (isLastFile) {
+ // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+ writeMetricsToUse = writeMetrics;
+ } else {
+ // We're spilling, so bytes written should be counted towards spill rather than write.
+ // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+ // them towards shuffle bytes written.
+ writeMetricsToUse = new ShuffleWriteMetrics();
+ }
+
+ // This call performs the actual sort.
+ final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+ sorter.getSortedIterator();
+
+ // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+ // after SPARK-5581 is fixed.
+ BlockObjectWriter writer;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array. This array does not need to be large enough to hold a single
+ // record.
+ final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ // Because this output will be read during shuffle, its compression codec must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more details.
+ final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = spilledFileInfo._2();
+ final TempShuffleBlockId blockId = spilledFileInfo._1();
+ final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE;
+
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+
+ int currentPartition = -1;
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+ assert (partition >= currentPartition);
+ if (partition != currentPartition) {
+ // Switch to the new partition
+ if (currentPartition != -1) {
+ writer.commitAndClose();
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ }
+ currentPartition = partition;
+ writer =
+ blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+ }
+
+ final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+ final Object recordPage = memoryManager.getPage(recordPointer);
+ final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+ int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+ long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+ PlatformDependent.copyMemory(
+ recordPage,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ toTransfer);
+ writer.write(writeBuffer, 0, toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ }
+ writer.recordWritten();
+ }
+
+ if (writer != null) {
+ writer.commitAndClose();
+ // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+ // then the file might be empty. Note that it might be better to avoid calling
+ // writeSortedFile() in that case.
+ if (currentPartition != -1) {
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ spills.add(spillInfo);
+ }
+ }
+
+ if (!isLastFile) { // i.e. this is a spill file
+ // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+ // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+ // relies on its `recordWritten()` method being called in order to trigger periodic updates to
+ // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
+ // counter at a higher-level, then the in-progress metrics for records written and bytes
+ // written would get out of sync.
+ //
+ // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+ // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
+ // metrics to the true write metrics here. The reason for performing this copying is so that
+ // we can avoid reporting spilled bytes as shuffle write bytes.
+ //
+ // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+ // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+ // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+ writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+ taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+ }
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spills.size(),
+ spills.size() > 1 ? "times" : "time");
+
+ writeSortedFile(false);
+ final long sorterMemoryUsage = sorter.getMemoryUsage();
+ sorter = null;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+ }
+
+ private long freeMemory() {
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ memoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Force all memory and spill files to be deleted; called by shuffle error-handling code.
+ */
+ public void cleanupAfterError() {
+ freeMemory();
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Unable to delete spill file {}", spill.file.getPath());
+ }
+ }
+ if (sorter != null) {
+ shuffleMemoryManager.release(sorter.getMemoryUsage());
+ sorter = null;
+ }
+ }
+
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ *
+ * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+ assert (requiredSpace > 0);
+ return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ */
+ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ sorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+ if (requiredSpace > freeSpaceInCurrentPage) {
+ logger.trace("Required space {} exceeds free space in current page ({})", requiredSpace,
+ freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > PAGE_SIZE) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ PAGE_SIZE + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquired < PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+ }
+ }
+ currentPage = memoryManager.allocatePage(PAGE_SIZE);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = PAGE_SIZE;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the shuffle sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ int partitionId) throws IOException {
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+ currentPagePosition += 4;
+ freeSpaceInCurrentPage -= 4;
+ PlatformDependent.copyMemory(
+ recordBaseObject,
+ recordBaseOffset,
+ dataPageBaseObject,
+ currentPagePosition,
+ lengthInBytes);
+ currentPagePosition += lengthInBytes;
+ freeSpaceInCurrentPage -= lengthInBytes;
+ sorter.insertRecord(recordAddress, partitionId);
+ }
+
+ /**
+ * Close the sorter, causing any buffered data to be sorted and written out to disk.
+ *
+ * @return metadata for the spill files written by this sorter. If no records were ever inserted
+ * into this sorter, then this will return an empty array.
+ * @throws IOException
+ */
+ public SpillInfo[] closeAndGetSpills() throws IOException {
+ try {
+ if (sorter != null) {
+ // Do not count the final file towards the spill count.
+ writeSortedFile(true);
+ freeMemory();
+ }
+ return spills.toArray(new SpillInfo[spills.size()]);
+ } catch (IOException e) {
+ cleanupAfterError();
+ throw e;
+ }
+ }
+
+}
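
A simplified illustration of the record layout that insertRecord() writes and writeSortedFile() reads back: each record is a 4-byte length prefix followed by the serialized bytes, copied out through a small transfer buffer. This sketch is not part of the patch; it substitutes an on-heap ByteBuffer for a managed-memory page and a ByteArrayOutputStream for the DiskBlockObjectWriter.

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.nio.ByteBuffer;

    public class PageLayoutSketch {
      public static void main(String[] args) throws IOException {
        final ByteBuffer page = ByteBuffer.allocate(1 << 10); // stand-in for a 128 MB data page
        final byte[] record = "hello, shuffle".getBytes("UTF-8");

        // insertRecord(): write the 4-byte length prefix, then the record bytes.
        page.putInt(record.length);
        page.put(record);

        // writeSortedFile(): read the length, then copy the payload through a small buffer.
        page.flip();
        final byte[] transferBuffer = new byte[4]; // stand-in for DISK_WRITE_BUFFER_SIZE
        final ByteArrayOutputStream out = new ByteArrayOutputStream(); // stand-in for the disk writer
        int remaining = page.getInt();
        while (remaining > 0) {
          final int toTransfer = Math.min(transferBuffer.length, remaining);
          page.get(transferBuffer, 0, toTransfer);
          out.write(transferBuffer, 0, toTransfer);
          remaining -= toTransfer;
        }
        System.out.println(out.toString("UTF-8")); // prints "hello, shuffle"
      }
    }
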
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
new file mode 100644
index 0000000000..5bab501da9
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Comparator;
+
+import org.apache.spark.util.collection.Sorter;
+
+final class UnsafeShuffleInMemorySorter {
+
+ private final Sorter<PackedRecordPointer, long[]> sorter;
+ private static final class SortComparator implements Comparator<PackedRecordPointer> {
+ @Override
+ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
+ return left.getPartitionId() - right.getPartitionId();
+ }
+ }
+ private static final SortComparator SORT_COMPARATOR = new SortComparator();
+
+ /**
+ * An array of record pointers and partition ids that have been encoded by
+ * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
+ * records.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the pointer array where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeShuffleInMemorySorter(int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize];
+ this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 1 < pointerArray.length;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ /**
+ * Inserts a record to be sorted.
+ *
+ * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
+ * certain pointer compression techniques used by the sorter, the sort can
+ * only operate on pointers that point to locations in the first
+ * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
+ * @param partitionId the partition id, which must be less than or equal to
+ * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
+ */
+ public void insertRecord(long recordPointer, int partitionId) {
+ if (!hasSpaceForAnotherRecord()) {
+ if (pointerArray.length == Integer.MAX_VALUE) {
+ throw new IllegalStateException("Sort pointer array has reached maximum size");
+ } else {
+ expandPointerArray();
+ }
+ }
+ pointerArray[pointerArrayInsertPosition] =
+ PackedRecordPointer.packPointer(recordPointer, partitionId);
+ pointerArrayInsertPosition++;
+ }
+
+ /**
+ * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
+ */
+ public static final class UnsafeShuffleSorterIterator {
+
+ private final long[] pointerArray;
+ private final int numRecords;
+ final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
+ private int position = 0;
+
+ public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+ this.numRecords = numRecords;
+ this.pointerArray = pointerArray;
+ }
+
+ public boolean hasNext() {
+ return position < numRecords;
+ }
+
+ public void loadNext() {
+ packedRecordPointer.set(pointerArray[position]);
+ position++;
+ }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order.
+ */
+ public UnsafeShuffleSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
+ return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+ }
+}
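
To make the sort concrete: only the packed 8-byte pointers are reordered, by partition id, while the record bytes stay where they were written. A minimal, illustrative equivalent follows (not part of the patch; it boxes the pointers for brevity, whereas the real sorter uses Spark's Sorter over a primitive long[] precisely to avoid that allocation):

    import java.util.Arrays;
    import java.util.Comparator;

    public class PointerSortSketch {
      static int partitionId(long packed) {
        return (int) (packed >>> 40); // the upper 24 bits hold the partition id
      }

      public static void main(String[] args) {
        // Hypothetical packed pointers: partition ids 3, 1, 2 with arbitrary 40-bit addresses.
        Long[] pointers = {
            (3L << 40) | 0x123L,
            (1L << 40) | 0x456L,
            (2L << 40) | 0x789L,
        };
        Arrays.sort(pointers, Comparator.comparingInt(PointerSortSketch::partitionId));
        for (long p : pointers) {
          System.out.println(partitionId(p) + " -> address " + (p & ((1L << 40) - 1)));
        }
        // Output order: partitions 1, 2, 3 -- records are then written out one partition at a time.
      }
    }
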
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
new file mode 100644
index 0000000000..a66d74ee44
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+
+ public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+
+ private UnsafeShuffleSortDataFormat() { }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public PackedRecordPointer newKey() {
+ return new PackedRecordPointer();
+ }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
+ reuse.set(data[pos]);
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ final long temp = data[pos0];
+ data[pos0] = data[pos1];
+ data[pos1] = temp;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos] = src[srcPos];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos, dst, dstPos, length);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ return new long[length];
+ }
+
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
new file mode 100644
index 0000000000..ad7eb04afc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+import javax.annotation.Nullable;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConversions;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+@Private
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+ private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+
+ @VisibleForTesting
+ static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+
+ private final BlockManager blockManager;
+ private final IndexShuffleBlockResolver shuffleBlockResolver;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final SerializerInstance serializer;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final int shuffleId;
+ private final int mapId;
+ private final TaskContext taskContext;
+ private final SparkConf sparkConf;
+ private final boolean transferToEnabled;
+
+ private MapStatus mapStatus = null;
+ private UnsafeShuffleExternalSorter sorter = null;
+
+ /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+ public MyByteArrayOutputStream(int size) { super(size); }
+ public byte[] getBuf() { return buf; }
+ }
+
+ private MyByteArrayOutputStream serBuffer;
+ private SerializationStream serOutputStream;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop() with success = true
+ * and then call stop() with success = false if they get an exception, we want to make sure
+ * we don't try deleting files, etc twice.
+ */
+ private boolean stopping = false;
+
+ public UnsafeShuffleWriter(
+ BlockManager blockManager,
+ IndexShuffleBlockResolver shuffleBlockResolver,
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ UnsafeShuffleHandle<K, V> handle,
+ int mapId,
+ TaskContext taskContext,
+ SparkConf sparkConf) throws IOException {
+ final int numPartitions = handle.dependency().partitioner().numPartitions();
+ if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
+ throw new IllegalArgumentException(
+ "UnsafeShuffleWriter can only be used for shuffles with at most " +
+ UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
+ }
+ this.blockManager = blockManager;
+ this.shuffleBlockResolver = shuffleBlockResolver;
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.mapId = mapId;
+ final ShuffleDependency<K, V, V> dep = handle.dependency();
+ this.shuffleId = dep.shuffleId();
+ this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+ this.partitioner = dep.partitioner();
+ this.writeMetrics = new ShuffleWriteMetrics();
+ taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+ this.taskContext = taskContext;
+ this.sparkConf = sparkConf;
+ this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+ open();
+ }
+
+ /**
+ * This convenience method should only be called in test code.
+ */
+ @VisibleForTesting
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
+ write(JavaConversions.asScalaIterator(records));
+ }
+
+ @Override
+ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+ boolean success = false;
+ try {
+ while (records.hasNext()) {
+ insertRecordIntoSorter(records.next());
+ }
+ closeAndWriteOutput();
+ success = true;
+ } finally {
+ if (!success) {
+ sorter.cleanupAfterError();
+ }
+ }
+ }
+
+ private void open() throws IOException {
+ assert (sorter == null);
+ sorter = new UnsafeShuffleExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ INITIAL_SORT_BUFFER_SIZE,
+ partitioner.numPartitions(),
+ sparkConf,
+ writeMetrics);
+ serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+ serOutputStream = serializer.serializeStream(serBuffer);
+ }
+
+ @VisibleForTesting
+ void closeAndWriteOutput() throws IOException {
+ serBuffer = null;
+ serOutputStream = null;
+ final SpillInfo[] spills = sorter.closeAndGetSpills();
+ sorter = null;
+ final long[] partitionLengths;
+ try {
+ partitionLengths = mergeSpills(spills);
+ } finally {
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && ! spill.file.delete()) {
+ logger.error("Error while deleting spill file {}", spill.file.getPath());
+ }
+ }
+ }
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ }
+
+ @VisibleForTesting
+ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+ final K key = record._1();
+ final int partitionId = partitioner.getPartition(key);
+ serBuffer.reset();
+ serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+ serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+ serOutputStream.flush();
+
+ final int serializedRecordSize = serBuffer.size();
+ assert (serializedRecordSize > 0);
+
+ sorter.insertRecord(
+ serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+ }
+
+ @VisibleForTesting
+ void forceSorterToSpill() throws IOException {
+ assert (sorter != null);
+ sorter.spill();
+ }
+
+ /**
+ * Merge zero or more spill files together, choosing the fastest merging strategy based on the
+ * number of spills and the IO compression codec.
+ *
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+ final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
+ final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
+ final boolean fastMergeEnabled =
+ sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+ final boolean fastMergeIsSupported =
+ !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
+ try {
+ if (spills.length == 0) {
+ new FileOutputStream(outputFile).close(); // Create an empty file
+ return new long[partitioner.numPartitions()];
+ } else if (spills.length == 1) {
+ // Here, we don't need to perform any metrics updates because the bytes written to this
+ // output file would have already been counted as shuffle bytes written.
+ Files.move(spills[0].file, outputFile);
+ return spills[0].partitionLengths;
+ } else {
+ final long[] partitionLengths;
+ // There are multiple spills to merge, so none of these spill files' lengths were counted
+ // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+ // then the final output file's size won't necessarily be equal to the sum of the spill
+ // files' sizes. To guard against this case, we look at the output file's actual size when
+ // computing shuffle bytes written.
+ //
+ // We allow the individual merge methods to report their own IO times since different merge
+ // strategies use different IO techniques. We count IO during merge towards the shuffle
+ // write time, which appears to be consistent with the "not bypassing merge-sort"
+ // branch in ExternalSorter.
+ if (fastMergeEnabled && fastMergeIsSupported) {
+ // Compression is disabled or we are using an IO compression codec that supports
+ // decompression of concatenated compressed streams, so we can perform a fast spill merge
+ // that doesn't need to interpret the spilled bytes.
+ if (transferToEnabled) {
+ logger.debug("Using transferTo-based fast merge");
+ partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+ } else {
+ logger.debug("Using fileStream-based fast merge");
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+ }
+ } else {
+ logger.debug("Using slow merge");
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ }
+ // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
+ // in-memory records, we write out the in-memory records to a file but do not count that
+ // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+ // to be counted as shuffle write, but this will lead to double-counting of the final
+ // SpillInfo's bytes.
+ writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
+ writeMetrics.incShuffleBytesWritten(outputFile.length());
+ return partitionLengths;
+ }
+ } catch (IOException e) {
+ if (outputFile.exists() && !outputFile.delete()) {
+ logger.error("Unable to delete output file {}", outputFile.getPath());
+ }
+ throw e;
+ }
+ }
+
+ /**
+ * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+ * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+ * cases where the IO compression codec does not support concatenation of compressed data, or in
+ * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+ * kernel bugs.
+ *
+ * @param spills the spills to merge.
+ * @param outputFile the file to write the merged data to.
+ * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpillsWithFileStream(
+ SpillInfo[] spills,
+ File outputFile,
+ @Nullable CompressionCodec compressionCodec) throws IOException {
+ assert (spills.length >= 2);
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final InputStream[] spillInputStreams = new FileInputStream[spills.length];
+ OutputStream mergedFileOutputStream = null;
+
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputStreams[i] = new FileInputStream(spills[i].file);
+ }
+ for (int partition = 0; partition < numPartitions; partition++) {
+ final long initialFileLength = outputFile.length();
+ mergedFileOutputStream =
+ new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+ if (compressionCodec != null) {
+ mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+ }
+
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream =
+ new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+ }
+ }
+ mergedFileOutputStream.flush();
+ mergedFileOutputStream.close();
+ partitionLengths[partition] = (outputFile.length() - initialFileLength);
+ }
+ threwException = false;
+ } finally {
+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+ // throw exceptions during cleanup if threwException == false.
+ for (InputStream stream : spillInputStreams) {
+ Closeables.close(stream, threwException);
+ }
+ Closeables.close(mergedFileOutputStream, threwException);
+ }
+ return partitionLengths;
+ }
+
+ /**
+ * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+ * This is only safe when the IO compression codec and serializer support concatenation of
+ * serialized streams.
+ *
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+ assert (spills.length >= 2);
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final FileChannel[] spillInputChannels = new FileChannel[spills.length];
+ final long[] spillInputChannelPositions = new long[spills.length];
+ FileChannel mergedFileOutputChannel = null;
+
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+ }
+ // This file needs to be opened in append mode in order to work around a Linux kernel bug that
+ // affects transferTo; see SPARK-3948 for more details.
+ mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+ long bytesWrittenToMergedFile = 0;
+ for (int partition = 0; partition < numPartitions; partition++) {
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ long bytesToTransfer = partitionLengthInSpill;
+ final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
+ while (bytesToTransfer > 0) {
+ final long actualBytesTransferred = spillInputChannel.transferTo(
+ spillInputChannelPositions[i],
+ bytesToTransfer,
+ mergedFileOutputChannel);
+ spillInputChannelPositions[i] += actualBytesTransferred;
+ bytesToTransfer -= actualBytesTransferred;
+ }
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ bytesWrittenToMergedFile += partitionLengthInSpill;
+ partitionLengths[partition] += partitionLengthInSpill;
+ }
+ }
+ // Check the position after transferTo loop to see if it is in the right position and raise an
+ // exception if it is incorrect. The position will not be increased to the expected length
+ // after calling transferTo in kernel version 2.6.32. This issue is described at
+ // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+ if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+ throw new IOException(
+ "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+ "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+ " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+ "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+ "to disable this NIO feature."
+ );
+ }
+ threwException = false;
+ } finally {
+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+ // throw exceptions during cleanup if threwException == false.
+ for (int i = 0; i < spills.length; i++) {
+ assert(spillInputChannelPositions[i] == spills[i].file.length());
+ Closeables.close(spillInputChannels[i], threwException);
+ }
+ Closeables.close(mergedFileOutputChannel, threwException);
+ }
+ return partitionLengths;
+ }
+
+ @Override
+ public Option<MapStatus> stop(boolean success) {
+ try {
+ if (stopping) {
+ return Option.apply(null);
+ } else {
+ stopping = true;
+ if (success) {
+ if (mapStatus == null) {
+ throw new IllegalStateException("Cannot call stop(true) without having called write()");
+ }
+ return Option.apply(mapStatus);
+ } else {
+ // The map task failed, so delete our output data.
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+ return Option.apply(null);
+ }
+ }
+ } finally {
+ if (sorter != null) {
+ // If sorter is non-null, then this implies that we called stop() in response to an error,
+ // so we need to clean up memory and spill files created by the sorter
+ sorter.cleanupAfterError();
+ }
+ }
+ }
+}
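
For reference, the channel-to-channel copy that mergeSpillsWithTransferTo() relies on can be illustrated with plain JDK APIs. The sketch below is not part of the patch (the class and helper names are hypothetical); it concatenates files with FileChannel.transferTo(), looping because transferTo() may copy fewer bytes than requested, just as the writer's merge loop does:

    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.channels.FileChannel;

    public class TransferToSketch {
      // Appends each input file to the output without copying bytes through user space.
      public static void concatenate(File[] inputs, File output) throws IOException {
        try (FileChannel out = new FileOutputStream(output, true).getChannel()) {
          for (File input : inputs) {
            try (FileChannel in = new FileInputStream(input).getChannel()) {
              long position = 0;
              long remaining = in.size();
              while (remaining > 0) {
                final long transferred = in.transferTo(position, remaining, out);
                position += transferred;
                remaining -= transferred;
              }
            }
          }
        }
      }
    }
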
diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
new file mode 100644
index 0000000000..dc2aa30466
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+
+/**
+ * Intercepts write calls and tracks total time spent writing in order to update shuffle write
+ * metrics. Not thread safe.
+ */
+@Private
+public final class TimeTrackingOutputStream extends OutputStream {
+
+ private final ShuffleWriteMetrics writeMetrics;
+ private final OutputStream outputStream;
+
+ public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+ this.writeMetrics = writeMetrics;
+ this.outputStream = outputStream;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.write(b);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void write(byte[] b) throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.write(b);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.write(b, off, len);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void flush() throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.flush();
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void close() throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.close();
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+}
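
As a point of reference, a minimal usage sketch (Scala, REPL-style; the output path, buffer size, and the shuffleWriteTime accessor on ShuffleWriteMetrics are assumptions for illustration). The buffered stream sits above the timed layer, as DiskBlockObjectWriter does further below, so only real flushes to the underlying file are counted:

    import java.io.{BufferedOutputStream, FileOutputStream}
    import org.apache.spark.executor.ShuffleWriteMetrics
    import org.apache.spark.storage.TimeTrackingOutputStream

    val metrics = new ShuffleWriteMetrics()
    // Wrap the raw file stream so that every write/flush/close on it is timed
    // and accumulated into the metrics object.
    val timed = new TimeTrackingOutputStream(metrics, new FileOutputStream("/tmp/ttos-example.bin"))
    val out = new BufferedOutputStream(timed, 32 * 1024)
    out.write(Array.fill[Byte](1024)(42.toByte))
    out.close()  // flushes the buffer, then closes (and times) the wrapped stream
    println(s"nanoseconds spent writing: ${metrics.shuffleWriteTime}")
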
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0c4d28f786..a5d831c7e6 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -313,7 +313,8 @@ object SparkEnv extends Logging {
// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
- "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+ "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
+ "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
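
A short, hedged sketch of how the new short name is meant to be used (standalone Scala; either form resolves to the same manager class through the lookup above, since any value that is not a registered short name is treated as a fully-qualified class name):

    import org.apache.spark.SparkConf

    // Short name, resolved via shortShuffleMgrNames:
    val byShortName = new SparkConf().set("spark.shuffle.manager", "tungsten-sort")
    // Equivalent fully-qualified class name:
    val byClassName = new SparkConf()
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
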
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index dfbde7c8a1..698d1384d5 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)
+ protected def this() = this(new SparkConf()) // For deserialization only
+
override def newInstance(): SerializerInstance = {
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6ad427bcac..6c3b3080d2 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
private val consolidateShuffleFiles =
conf.getBoolean("spark.shuffle.consolidateFiles", false)
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index f6e6fe5def..4cc4ef5f18 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -17,14 +17,17 @@
package org.apache.spark.shuffle
+import java.io.IOException
+
import org.apache.spark.scheduler.MapStatus
/**
* Obtained inside a map task to write out records to the shuffle system.
*/
-private[spark] trait ShuffleWriter[K, V] {
+private[spark] abstract class ShuffleWriter[K, V] {
/** Write a sequence of records to this task's output */
- def write(records: Iterator[_ <: Product2[K, V]]): Unit
+ @throws[IOException]
+ def write(records: Iterator[Product2[K, V]]): Unit
/** Close this writer, passing along whether the map completed */
def stop(success: Boolean): Option[MapStatus]
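
To make the contract concrete, a hedged sketch of how a map task is expected to drive a ShuffleWriter (a hypothetical helper, not the real ShuffleMapTask code; it has to live under an org.apache.spark package because these APIs are private[spark]). write() may now surface an IOException, and stop() is always invoked, with success = false on failure:

    package org.apache.spark.shuffle

    import org.apache.spark.{Partition, ShuffleDependency, TaskContext}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.scheduler.MapStatus

    // Hypothetical illustration of the write()/stop() contract.
    private[spark] object ShuffleWriterContractSketch {
      def runMapTask[K, V](
          manager: ShuffleManager,
          dep: ShuffleDependency[K, V, V],
          rdd: RDD[Product2[K, V]],
          partition: Partition,
          mapId: Int,
          context: TaskContext): MapStatus = {
        var writer: ShuffleWriter[K, V] = null
        try {
          writer = manager.getWriter[K, V](dep.shuffleHandle, mapId, context)
          writer.write(rdd.iterator(partition, context))  // may throw IOException
          writer.stop(success = true).get
        } catch {
          case e: Exception =>
            if (writer != null) {
              writer.stop(success = false)
            }
            throw e
        }
      }
    }
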
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 897f0a5dc5..eb87cee159 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V](
writeMetrics)
/** Write a bunch of records to this task's output */
- override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ override def write(records: Iterator[Product2[K, V]]): Unit = {
val iter = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
dep.aggregator.get.combineValuesByKey(records, context)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 15842941da..d7fab351ca 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
true
}
- override def shuffleBlockResolver: IndexShuffleBlockResolver = {
+ override val shuffleBlockResolver: IndexShuffleBlockResolver = {
indexShuffleBlockResolver
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index add2656294..c9dd6bfc4c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -48,7 +48,7 @@ private[spark] class SortShuffleWriter[K, V, C](
context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
/** Write a bunch of records to this task's output */
- override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ override def write(records: Iterator[Product2[K, V]]): Unit = {
if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
sorter = new ExternalSorter[K, V, C](
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
new file mode 100644
index 0000000000..f2bfef376d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.SortShuffleManager
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
+ */
+private[spark] class UnsafeShuffleHandle[K, V](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+private[spark] object UnsafeShuffleManager extends Logging {
+
+ /**
+ * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
+ */
+ val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+ /**
+ * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
+ * path or whether it should fall back to the original sort-based shuffle.
+ */
+ def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
+ val shufId = dependency.shuffleId
+ val serializer = Serializer.getSerializer(dependency.serializer)
+ if (!serializer.supportsRelocationOfSerializedObjects) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
+ s"${serializer.getClass.getName}, does not support object relocation")
+ false
+ } else if (dependency.aggregator.isDefined) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
+ false
+ } else if (dependency.keyOrdering.isDefined) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
+ false
+ } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
+ s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
+ false
+ } else {
+ log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
+ true
+ }
+ }
+}
+
+/**
+ * A shuffle implementation that uses directly-managed memory to implement several performance
+ * optimizations for certain types of shuffles. In cases where the new performance optimizations
+ * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
+ * shuffles.
+ *
+ * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
+ *
+ * - The shuffle dependency specifies no aggregation or output ordering.
+ * - The shuffle serializer supports relocation of serialized values (this is currently supported
+ * by KryoSerializer and Spark SQL's custom serializers).
+ * - The shuffle produces no more than 16777216 output partitions.
+ * - No individual record is larger than 128 MB when serialized.
+ *
+ * In addition, extra spill-merging optimizations are automatically applied when the shuffle
+ * compression codec supports concatenation of serialized streams. This is currently supported by
+ * Spark's LZF compression codec.
+ *
+ * At a high level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * UnsafeShuffleManager optimizes this process in several ways:
+ *
+ * - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ * consumption and GC overheads. This optimization requires the record serializer to have certain
+ * properties to allow serialized records to be re-ordered without requiring deserialization.
+ * See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
+ * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ * record in the sorting array, this fits more of the array into cache.
+ *
+ * - The spill merging procedure operates on blocks of serialized records that belong to the same
+ * partition and does not need to deserialize records during the merge.
+ *
+ * - When the spill compression codec supports concatenation of compressed data, the spill merge
+ * simply concatenates the serialized and compressed spill partitions to produce the final output
+ * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ * and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on UnsafeShuffleManager's design, see SPARK-7081.
+ */
+private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+ if (!conf.getBoolean("spark.shuffle.spill", true)) {
+ logWarning(
+ "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
+ "manager; its optimized shuffles will continue to spill to disk when necessary.")
+ }
+
+ private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+ private[this] val shufflesThatFellBackToSortShuffle =
+ Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+ private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
+
+ /**
+ * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
+ new UnsafeShuffleHandle[K, V](
+ shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else {
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C] = {
+ sortShuffleManager.getReader(handle, startPartition, endPartition, context)
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Int,
+ context: TaskContext): ShuffleWriter[K, V] = {
+ handle match {
+ case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+ numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
+ val env = SparkEnv.get
+ new UnsafeShuffleWriter(
+ env.blockManager,
+ shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ context.taskMemoryManager(),
+ env.shuffleMemoryManager,
+ unsafeShuffleHandle,
+ mapId,
+ context,
+ env.conf)
+ case other =>
+ shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
+ sortShuffleManager.getWriter(handle, mapId, context)
+ }
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+ sortShuffleManager.unregisterShuffle(shuffleId)
+ } else {
+ Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+ (0 until numMaps).foreach { mapId =>
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+ }
+ }
+ true
+ }
+ }
+
+ override val shuffleBlockResolver: IndexShuffleBlockResolver = {
+ sortShuffleManager.shuffleBlockResolver
+ }
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {
+ sortShuffleManager.stop()
+ }
+}
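
A self-contained sketch of the two code paths from the user's side (hedged: the local master and app name are placeholders). With Kryo configured, a plain repartitioning shuffle satisfies every condition listed above and should take the optimized path, while reduceByKey defines a map-side aggregator and falls back to sort-based shuffle:

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("tungsten-sort-sketch")
      .set("spark.shuffle.manager", "tungsten-sort")
      // KryoSerializer supports relocation of serialized objects, which the
      // optimized path requires; the default JavaSerializer does not.
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)
    val pairs = sc.parallelize(1 to 1000).map(i => (i % 100, i))
    pairs.partitionBy(new HashPartitioner(100)).count()  // no aggregator or ordering: eligible
    pairs.reduceByKey(_ + _).count()                     // aggregator defined: falls back
    sc.stop()
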
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 8bc4e205bc..a33f22ef52 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter(
extends BlockObjectWriter(blockId)
with Logging
{
- /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
- private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
- override def write(i: Int): Unit = callWithTiming(out.write(i))
- override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b))
- override def write(b: Array[Byte], off: Int, len: Int): Unit = {
- callWithTiming(out.write(b, off, len))
- }
- override def close(): Unit = out.close()
- override def flush(): Unit = out.flush()
- }
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
@@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter(
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
fos = new FileOutputStream(file, true)
- ts = new TimeTrackingOutputStream(fos)
+ ts = new TimeTrackingOutputStream(writeMetrics, fos)
channel = fos.getChannel()
bs = compressStream(new BufferedOutputStream(ts, bufferSize))
objOut = serializerInstance.serializeStream(bs)
@@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter(
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
- callWithTiming {
- fos.getFD.sync()
- }
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
}
} {
objOut.close()
@@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter(
reportedPosition = pos
}
- private def callWithTiming(f: => Unit) = {
- val start = System.nanoTime()
- f
- writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
- }
-
// For testing
private[spark] override def flush() {
objOut.flush()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index b850973145..df2d6ad3b4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C](
// Number of bytes spilled in total
private var _diskBytesSpilled = 0L
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize =
sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7d5cf7b61e..3b9d14f937 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C](
private val conf = SparkEnv.get.conf
private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
new file mode 100644
index 0000000000..db9e827590
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+ @Test
+ public void heap() {
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock page0 = memoryManager.allocatePage(100);
+ final MemoryBlock page1 = memoryManager.allocatePage(100);
+ final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+ page1.getBaseOffset() + 42);
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+ assertEquals(360, packedPointer.getPartitionId());
+ final long recordPointer = packedPointer.getRecordPointer();
+ assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+ assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+ assertEquals(addressInPage1, recordPointer);
+ memoryManager.cleanUpAllAllocatedMemory();
+ }
+
+ @Test
+ public void offHeap() {
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+ final MemoryBlock page0 = memoryManager.allocatePage(100);
+ final MemoryBlock page1 = memoryManager.allocatePage(100);
+ final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+ page1.getBaseOffset() + 42);
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+ assertEquals(360, packedPointer.getPartitionId());
+ final long recordPointer = packedPointer.getRecordPointer();
+ assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+ assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+ assertEquals(addressInPage1, recordPointer);
+ memoryManager.cleanUpAllAllocatedMemory();
+ }
+
+ @Test
+ public void maximumPartitionIdCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+ assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+ }
+
+ @Test
+ public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ try {
+ // Partition ids greater than MAXIMUM_PARTITION_ID will overflow or trigger an assertion error
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+ assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId());
+ } catch (AssertionError e) {
+ // pass
+ }
+ }
+
+ @Test
+ public void maximumOffsetInPageCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(address, packedPointer.getRecordPointer());
+ }
+
+ @Test
+ public void offsetsPastMaxOffsetInPageWillOverflow() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(0, packedPointer.getRecordPointer());
+ }
+}
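
For orientation, a small arithmetic sketch (Scala) of the limits these tests exercise, assuming the 24-bit partition / 13-bit page / 27-bit offset layout documented in PackedRecordPointer; the derived values line up with MAXIMUM_PARTITION_ID, MAXIMUM_PAGE_SIZE_BYTES, and the 16777216-partition and 128 MB limits quoted in the UnsafeShuffleManager scaladoc above:

    // Assumed packed 64-bit layout: [24-bit partition id][13-bit page number][27-bit offset in page]
    val partitionBits = 24
    val pageBits = 13
    val offsetBits = 27
    require(partitionBits + pageBits + offsetBits == 64)

    val maxPartitionId = (1 << partitionBits) - 1          // 16777215 (MAXIMUM_PARTITION_ID)
    val maxPageSizeBytes = 1L << offsetBits                // 134217728 bytes = 128 MB (MAXIMUM_PAGE_SIZE_BYTES)
    val maxPages = 1 << pageBits                           // 8192 addressable pages per task
    val maxAddressableBytes = maxPages * maxPageSizeBytes  // 8192 * 128 MB = 1 TB of page space per task

    println(s"partitions: ${maxPartitionId + 1}, page size: $maxPageSizeBytes B, total: $maxAddressableBytes B")
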
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000000..8fa72597db
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeShuffleInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+ final byte[] strBytes = new byte[strLength];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
+ return new String(strBytes);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
+ final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ assert(!iter.hasNext());
+ }
+
+ @Test
+ public void testBasicSorting() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+ // Write the records into the data page and store pointers into the sorter
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+ final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final byte[] strBytes = str.getBytes("utf-8");
+ PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ PlatformDependent.copyMemory(
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ position,
+ strBytes.length);
+ position += strBytes.length;
+ sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+ }
+
+ // Sort the records
+ final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ int prevPartitionId = -1;
+ Arrays.sort(dataToSort);
+ for (int i = 0; i < dataToSort.length; i++) {
+ Assert.assertTrue(iter.hasNext());
+ iter.loadNext();
+ final int partitionId = iter.packedRecordPointer.getPartitionId();
+ Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+ Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
+ partitionId >= prevPartitionId);
+ final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+ final int recordLength = PlatformDependent.UNSAFE.getInt(
+ memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+ final String str = getStringFromDataPage(
+ memoryManager.getPage(recordAddress),
+ memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+ recordLength);
+ Assert.assertTrue(Arrays.binarySearch(dataToSort, str) >= 0);
+ }
+ Assert.assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testSortingManyNumbers() throws Exception {
+ UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+ int[] numbersToSort = new int[128000];
+ Random random = new Random(16);
+ for (int i = 0; i < numbersToSort.length; i++) {
+ numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+ sorter.insertRecord(0, numbersToSort[i]);
+ }
+ Arrays.sort(numbersToSort);
+ int[] sorterResult = new int[numbersToSort.length];
+ UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ int j = 0;
+ while (iter.hasNext()) {
+ iter.loadNext();
+ sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+ j += 1;
+ }
+ Assert.assertArrayEquals(numbersToSort, sorterResult);
+ }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000000..730d265c87
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+ static final int NUM_PARTITIONS = 4;
+ final TaskMemoryManager taskMemoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITIONS);
+ File mergedOutputFile;
+ File tempDir;
+ long[] partitionSizesInMergedFile;
+ final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+ SparkConf conf;
+ final Serializer serializer = new KryoSerializer(new SparkConf());
+ TaskMetrics taskMetrics;
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+ private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+ } else {
+ return stream;
+ }
+ }
+ }
+
+ @After
+ public void tearDown() {
+ Utils.deleteRecursively(tempDir);
+ final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+ if (leakedMemory != 0) {
+ fail("Test leaked " + leakedMemory + " bytes of managed memory");
+ }
+ }
+
+ @Before
+ @SuppressWarnings("unchecked")
+ public void setUp() throws IOException {
+ MockitoAnnotations.initMocks(this);
+ tempDir = Utils.createTempDir("test", "test");
+ mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+ partitionSizesInMergedFile = null;
+ spillFilesCreated.clear();
+ conf = new SparkConf();
+ taskMetrics = new TaskMetrics();
+
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (BlockId) args[0],
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+ new Answer<InputStream>() {
+ @Override
+ public InputStream answer(InvocationOnMock invocation) throws Throwable {
+ assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+ InputStream is = (InputStream) invocation.getArguments()[1];
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+ } else {
+ return is;
+ }
+ }
+ }
+ );
+
+ when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+ new Answer<OutputStream>() {
+ @Override
+ public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+ assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+ OutputStream os = (OutputStream) invocation.getArguments()[1];
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+ } else {
+ return os;
+ }
+ }
+ }
+ );
+
+ when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+ return null;
+ }
+ }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+ when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+ new Answer<Tuple2<TempShuffleBlockId, File>>() {
+ @Override
+ public Tuple2<TempShuffleBlockId, File> answer(
+ InvocationOnMock invocationOnMock) throws Throwable {
+ TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ spillFilesCreated.add(file);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+
+ when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+ when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+ }
+
+ private UnsafeShuffleWriter<Object, Object> createWriter(
+ boolean transferToEnabled) throws IOException {
+ conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+ return new UnsafeShuffleWriter<Object, Object>(
+ blockManager,
+ shuffleBlockResolver,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+ 0, // map id
+ taskContext,
+ conf
+ );
+ }
+
+ private void assertSpillFilesWereCleanedUp() {
+ for (File spillFile : spillFilesCreated) {
+ assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+ spillFile.exists());
+ }
+ }
+
+ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+ final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+ long startOffset = 0;
+ for (int i = 0; i < NUM_PARTITIONS; i++) {
+ final long partitionSize = partitionSizesInMergedFile[i];
+ if (partitionSize > 0) {
+ InputStream in = new FileInputStream(mergedOutputFile);
+ ByteStreams.skipFully(in, startOffset);
+ in = new LimitedInputStream(in, partitionSize);
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+ }
+ DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+ Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+ while (records.hasNext()) {
+ Tuple2<Object, Object> record = records.next();
+ assertEquals(i, hashPartitioner.getPartition(record._1()));
+ recordsList.add(record);
+ }
+ recordsStream.close();
+ startOffset += partitionSize;
+ }
+ }
+ return recordsList;
+ }
+
+ @Test(expected=IllegalStateException.class)
+ public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+ createWriter(false).stop(true);
+ }
+
+ @Test
+ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+ createWriter(false).stop(false);
+ }
+
+ @Test
+ public void writeEmptyIterator() throws Exception {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+ writer.write(Collections.<Product2<Object, Object>>emptyIterator());
+ final Option<MapStatus> mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+ assertArrayEquals(new long[NUM_PARTITIONS], partitionSizesInMergedFile);
+ assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+ assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+ assertEquals(0, taskMetrics.diskBytesSpilled());
+ assertEquals(0, taskMetrics.memoryBytesSpilled());
+ }
+
+ @Test
+ public void writeWithoutSpilling() throws Exception {
+ // In this example, each partition should have exactly one record:
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ for (int i = 0; i < NUM_PARTITIONS; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+ writer.write(dataToWrite.iterator());
+ final Option<MapStatus> mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+
+ long sumOfPartitionSizes = 0;
+ for (long size: partitionSizesInMergedFile) {
+ // All partitions should be the same size:
+ assertEquals(partitionSizesInMergedFile[0], size);
+ sumOfPartitionSizes += size;
+ }
+ assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertEquals(0, taskMetrics.diskBytesSpilled());
+ assertEquals(0, taskMetrics.memoryBytesSpilled());
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ private void testMergingSpills(
+ boolean transferToEnabled,
+ String compressionCodecName) throws IOException {
+ if (compressionCodecName != null) {
+ conf.set("spark.shuffle.compress", "true");
+ conf.set("spark.io.compression.codec", compressionCodecName);
+ } else {
+ conf.set("spark.shuffle.compress", "false");
+ }
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ writer.insertRecordIntoSorter(dataToWrite.get(0));
+ writer.insertRecordIntoSorter(dataToWrite.get(1));
+ writer.insertRecordIntoSorter(dataToWrite.get(2));
+ writer.insertRecordIntoSorter(dataToWrite.get(3));
+ writer.forceSorterToSpill();
+ writer.insertRecordIntoSorter(dataToWrite.get(4));
+ writer.insertRecordIntoSorter(dataToWrite.get(5));
+ writer.closeAndWriteOutput();
+ final Option<MapStatus> mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+ assertEquals(2, spillFilesCreated.size());
+
+ long sumOfPartitionSizes = 0;
+ for (long size: partitionSizesInMergedFile) {
+ sumOfPartitionSizes += size;
+ }
+ assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndLZF() throws Exception {
+ testMergingSpills(true, LZFCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+ testMergingSpills(false, LZFCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+ testMergingSpills(true, LZ4CompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+ testMergingSpills(false, LZ4CompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+ testMergingSpills(true, SnappyCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+ testMergingSpills(false, SnappyCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+ testMergingSpills(true, null);
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+ testMergingSpills(false, null);
+ }
+
+ @Test
+ public void writeEnoughDataToTriggerSpill() throws Exception {
+ when(shuffleMemoryManager.tryToAcquire(anyLong()))
+ .then(returnsFirstArg()) // Allocate initial sort buffer
+ .then(returnsFirstArg()) // Allocate initial data page
+ .thenReturn(0L) // Deny request to allocate new data page
+ .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+ for (int i = 0; i < 128 + 1; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+ }
+ writer.write(dataToWrite.iterator());
+ verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ assertEquals(2, spillFilesCreated.size());
+ writer.stop(true);
+ readRecordsFromFile();
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+ when(shuffleMemoryManager.tryToAcquire(anyLong()))
+ .then(returnsFirstArg()) // Allocate initial sort buffer
+ .then(returnsFirstArg()) // Allocate initial data page
+ .thenReturn(0L) // Deny request to grow sort buffer
+ .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ writer.write(dataToWrite.iterator());
+ verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ assertEquals(2, spillFilesCreated.size());
+ writer.stop(true);
+ readRecordsFromFile();
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+ new Random(42).nextBytes(bytes);
+ dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+ writer.write(dataToWrite.iterator());
+ writer.stop(true);
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+ // Use a custom serializer so that we have exact control over the size of serialized data.
+ final Serializer byteArraySerializer = new Serializer() {
+ @Override
+ public SerializerInstance newInstance() {
+ return new SerializerInstance() {
+ @Override
+ public SerializationStream serializeStream(final OutputStream s) {
+ return new SerializationStream() {
+ @Override
+ public void flush() { }
+
+ @Override
+ public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+ byte[] bytes = (byte[]) t;
+ try {
+ s.write(bytes);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return this;
+ }
+
+ @Override
+ public void close() { }
+ };
+ }
+ public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
+ public DeserializationStream deserializeStream(InputStream s) { return null; }
+ public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
+ public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
+ };
+ }
+ };
+ when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ // Insert a record and force a spill so that there's something to clean up:
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
+ writer.forceSorterToSpill();
+ // We should be able to write a record that's right _at_ the max record size
+ final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+ new Random(42).nextBytes(atMaxRecordSize);
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
+ writer.forceSorterToSpill();
+ // Inserting a record that's larger than the max record size should fail:
+ final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+ new Random(42).nextBytes(exceedsMaxRecordSize);
+ Product2<Object, Object> hugeRecord =
+ new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
+ try {
+ // Here, we write through the public `write()` interface instead of the test-only
+ // `insertRecordIntoSorter` interface:
+ writer.write(Collections.singletonList(hugeRecord).iterator());
+ fail("Expected exception to be thrown");
+ } catch (IOException e) {
+ // Pass
+ }
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+ writer.forceSorterToSpill();
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+ writer.stop(false);
+ assertSpillFilesWereCleanedUp();
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 8c6035fb36..cf6a143537 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.io
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import com.google.common.io.ByteStreams
import org.scalatest.FunSuite
import org.apache.spark.SparkConf
@@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite {
testCodec(codec)
}
+ test("lz4 does not support concatenation of serialized streams") {
+ val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
+ assert(codec.getClass === classOf[LZ4CompressionCodec])
+ intercept[Exception] {
+ testConcatenationOfSerializedStreams(codec)
+ }
+ }
+
test("lzf compression codec") {
val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
assert(codec.getClass === classOf[LZFCompressionCodec])
@@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite {
testCodec(codec)
}
+ test("lzf supports concatenation of serialized streams") {
+ val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
+ assert(codec.getClass === classOf[LZFCompressionCodec])
+ testConcatenationOfSerializedStreams(codec)
+ }
+
test("snappy compression codec") {
val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
assert(codec.getClass === classOf[SnappyCompressionCodec])
@@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite {
testCodec(codec)
}
+ test("snappy does not support concatenation of serialized streams") {
+ val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
+ assert(codec.getClass === classOf[SnappyCompressionCodec])
+ intercept[Exception] {
+ testConcatenationOfSerializedStreams(codec)
+ }
+ }
+
test("bad compression codec") {
intercept[IllegalArgumentException] {
CompressionCodec.createCodec(conf, "foobar")
}
}
+
+ private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = {
+ val bytes1: Array[Byte] = {
+ val baos = new ByteArrayOutputStream()
+ val out = codec.compressedOutputStream(baos)
+ (0 to 64).foreach(out.write)
+ out.close()
+ baos.toByteArray
+ }
+ val bytes2: Array[Byte] = {
+ val baos = new ByteArrayOutputStream()
+ val out = codec.compressedOutputStream(baos)
+ (65 to 127).foreach(out.write)
+ out.close()
+ baos.toByteArray
+ }
+ val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2))
+ val decompressed: Array[Byte] = new Array[Byte](128)
+ ByteStreams.readFully(concatenatedBytes, decompressed)
+ assert(decompressed.toSeq === (0 to 127))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
new file mode 100644
index 0000000000..ed4d8ce632
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.SparkConf
+import org.scalatest.FunSuite
+
+class JavaSerializerSuite extends FunSuite {
+ test("JavaSerializer instances are serializable") {
+ val serializer = new JavaSerializer(new SparkConf())
+ val instance = serializer.newInstance()
+ instance.deserialize[JavaSerializer](instance.serialize(serializer))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
new file mode 100644
index 0000000000..49a04a2a45
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
+
+/**
+ * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
+ * performed in other suites.
+ */
+class UnsafeShuffleManagerSuite extends FunSuite with Matchers {
+
+ import UnsafeShuffleManager.canUseUnsafeShuffle
+
+ private class RuntimeExceptionAnswer extends Answer[Object] {
+ override def answer(invocation: InvocationOnMock): Object = {
+ throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
+ }
+ }
+
+ private def shuffleDep(
+ partitioner: Partitioner,
+ serializer: Option[Serializer],
+ keyOrdering: Option[Ordering[Any]],
+ aggregator: Option[Aggregator[Any, Any, Any]],
+ mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
+ val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
+ doReturn(0).when(dep).shuffleId
+ doReturn(partitioner).when(dep).partitioner
+ doReturn(serializer).when(dep).serializer
+ doReturn(keyOrdering).when(dep).keyOrdering
+ doReturn(aggregator).when(dep).aggregator
+ doReturn(mapSideCombine).when(dep).mapSideCombine
+ dep
+ }
+
+ test("supported shuffle dependencies") {
+ val kryo = Some(new KryoSerializer(new SparkConf()))
+
+ assert(canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
+ when(rangePartitioner.numPartitions).thenReturn(2)
+ assert(canUseUnsafeShuffle(shuffleDep(
+ partitioner = rangePartitioner,
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ }
+
+ test("unsupported shuffle dependencies") {
+ val kryo = Some(new KryoSerializer(new SparkConf()))
+ val java = Some(new JavaSerializer(new SparkConf()))
+
+ // We only support serializers that support object relocation
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = java,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ // We do not support shuffles with more than 16 million output partitions
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ // We do not support shuffles that perform any kind of aggregation or sorting of keys
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = Some(mock(classOf[Ordering[Any]])),
+ aggregator = None,
+ mapSideCombine = false
+ )))
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+ mapSideCombine = false
+ )))
+    // Combinations of the above are also unsupported
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = Some(mock(classOf[Ordering[Any]])),
+ aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+ mapSideCombine = true
+ )))
+ }
+
+}
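
A rough, self-contained sketch of the eligibility check this suite exercises (not the actual UnsafeShuffleManager.canUseUnsafeShuffle implementation; DepInfo, its fields, and the 2^24 constant are stand-ins inferred from the tests and comments above):

object CanUseUnsafeShuffleSketch {
  // Stand-in for the handful of fields read off the mocked ShuffleDependency above.
  final case class DepInfo(
      numPartitions: Int,
      serializerSupportsRelocation: Boolean,
      hasKeyOrdering: Boolean,
      hasAggregator: Boolean,
      mapSideCombine: Boolean)

  // Mirrors the "more than 16 million output partitions" limit mentioned above (assumed 2^24).
  val MaxShuffleOutputPartitions: Int = 1 << 24

  def canUseUnsafeShuffle(dep: DepInfo): Boolean =
    dep.serializerSupportsRelocation &&                // e.g. Kryo, but not JavaSerializer
      !dep.hasKeyOrdering &&                           // no sorting of keys
      !dep.hasAggregator &&                            // no aggregation
      !dep.mapSideCombine &&                           // no map-side combine
      dep.numPartitions <= MaxShuffleOutputPartitions  // partition id must stay small
}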
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
new file mode 100644
index 0000000000..6351539e91
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
+class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
+
+ override def beforeAll() {
+ conf.set("spark.shuffle.manager", "tungsten-sort")
+ // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort
+ // shuffle records.
+ conf.set("spark.shuffle.memoryFraction", "0.5")
+ }
+
+ test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+ val tmpDir = Utils.createTempDir()
+ try {
+ val myConf = conf.clone()
+ .set("spark.local.dir", tmpDir.getAbsolutePath)
+ sc = new SparkContext("local", "test", myConf)
+ // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new KryoSerializer(myConf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+ def getAllFiles: Set[File] =
+ FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+ val filesBeforeShuffle = getAllFiles
+ // Force the shuffle to be performed
+ shuffledRdd.count()
+ // Ensure that the shuffle actually created files that will need to be cleaned up
+ val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+ // Check that the cleanup actually removes the files
+ sc.env.blockManager.master.removeShuffle(0, blocking = true)
+ for (file <- filesCreatedByShuffle) {
+ assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+ }
+ } finally {
+ Utils.deleteRecursively(tmpDir)
+ }
+ }
+
+ test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+ val tmpDir = Utils.createTempDir()
+ try {
+ val myConf = conf.clone()
+ .set("spark.local.dir", tmpDir.getAbsolutePath)
+ sc = new SparkContext("local", "test", myConf)
+ // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new JavaSerializer(myConf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+ def getAllFiles: Set[File] =
+ FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+ val filesBeforeShuffle = getAllFiles
+ // Force the shuffle to be performed
+ shuffledRdd.count()
+ // Ensure that the shuffle actually created files that will need to be cleaned up
+ val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+ // Check that the cleanup actually removes the files
+ sc.env.blockManager.master.removeShuffle(0, blocking = true)
+ for (file <- filesCreatedByShuffle) {
+ assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+ }
+ } finally {
+ Utils.deleteRecursively(tmpDir)
+ }
+ }
+}
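
The beforeAll above selects the new write path with spark.shuffle.manager=tungsten-sort and raises spark.shuffle.memoryFraction for sorting. A minimal application-side sketch of the same configuration (app name and master are placeholders; whether the fast path is actually taken still depends on the eligibility rules tested earlier):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

object TungstenSortShuffleExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("tungsten-sort-example")  // placeholder
      .setMaster("local[2]")                // placeholder
      .set("spark.shuffle.manager", "tungsten-sort")
      .set("spark.shuffle.memoryFraction", "0.5")
    val sc = new SparkContext(conf)
    try {
      // Same shape as the test above: a Kryo-serialized, hash-partitioned shuffle with no
      // aggregation or key ordering, which is eligible for the new write path.
      val pairs = sc.parallelize(1 to 10, 1).map(x => (x, x))
      val shuffled = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(4))
        .setSerializer(new KryoSerializer(conf))
      println(shuffled.count())
    } finally {
      sc.stop()
    }
  }
}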
diff --git a/pom.xml b/pom.xml
index cf9279ea5a..564a443466 100644
--- a/pom.xml
+++ b/pom.xml
@@ -669,7 +669,7 @@
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
- <version>1.9.0</version>
+ <version>1.9.5</version>
<scope>test</scope>
</dependency>
<dependency>
@@ -685,6 +685,18 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-core</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-library</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
<version>0.10</version>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fba7290dcb..487062a31f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -131,6 +131,12 @@ object MimaExcludes {
// SPARK-7530 Added StreamingContext.getState()
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.StreamingContext.state_=")
+ ) ++ Seq(
+ // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
+ // unnecessary type bounds in order to fix some compiler warnings that occurred when
+ // implementing this interface in Java. Note that ShuffleWriter is private[spark].
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+ "org.apache.spark.shuffle.ShuffleWriter")
)
case v if v.startsWith("1.3") =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index c3d2c7019a..3e46596ecf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -17,17 +17,18 @@
package org.apache.spark.sql.execution
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.util.MutablePair
object Exchange {
@@ -85,7 +86,9 @@ case class Exchange(
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
// fewer partitions (like RangePartitioner, for example).
val conf = child.sqlContext.sparkContext.conf
- val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+ val shuffleManager = SparkEnv.get.shuffleManager
+ val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
+ shuffleManager.isInstanceOf[UnsafeShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
if (newOrdering.nonEmpty) {
@@ -93,11 +96,11 @@ case class Exchange(
// which requires a defensive copy.
true
} else if (sortBasedShuffleOn) {
- // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
- // However, there are two special cases where we can avoid the copy, described below:
- if (partitioner.numPartitions <= bypassMergeThreshold) {
- // If the number of output partitions is sufficiently small, then Spark will fall back to
- // the old hash-based shuffle write path which doesn't buffer deserialized records.
+ val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+ if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
+ // If we're using the original SortShuffleManager and the number of output partitions is
+ // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
+ // doesn't buffer deserialized records.
// Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
false
} else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
@@ -105,9 +108,14 @@ case class Exchange(
// them. This optimization is guarded by a feature-flag and is only applied in cases where
// shuffle dependency does not specify an ordering and the record serializer has certain
// properties. If this optimization is enabled, we can safely avoid the copy.
+ //
+ // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
false
} else {
- // None of the special cases held, so we must copy.
+ // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
+ // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
+ // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
+ // both cases, we must copy.
true
}
} else {
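
Restating the copy decision from the Exchange hunk above as a standalone predicate (a sketch only: parameter names are illustrative, and the final hash-shuffle branch returning false is an assumption based on that path writing records as they arrive):

object ExchangeCopySketch {
  def needsCopyBeforeShuffle(
      hasNewOrdering: Boolean,
      usesSortShuffleManager: Boolean,    // the original SortShuffleManager
      usesUnsafeShuffleManager: Boolean,  // the new UnsafeShuffleManager
      numPartitions: Int,
      bypassMergeThreshold: Int,
      serializeMapOutputs: Boolean,
      serializerSupportsRelocation: Boolean): Boolean = {
    val sortBasedShuffleOn = usesSortShuffleManager || usesUnsafeShuffleManager
    if (hasNewOrdering) {
      // Sorting buffers deserialized rows, so a defensive copy is required.
      true
    } else if (sortBasedShuffleOn) {
      val bypassIsSupported = usesSortShuffleManager
      if (bypassIsSupported && numPartitions <= bypassMergeThreshold) {
        // The bypass path writes records directly without buffering deserialized objects.
        false
      } else if (serializeMapOutputs && serializerSupportsRelocation) {
        // Records are serialized before buffering, so the copy can be avoided.
        false
      } else {
        // ExternalSorter buffers deserialized records (SortShuffleManager, or the fallback
        // path of UnsafeShuffleManager), so a copy is required.
        true
      }
    } else {
      // Assumed: hash-based shuffle writes records as they arrive, so no copy is needed.
      false
    }
  }
}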
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 5b0733206b..9e151fc7a9 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -42,6 +42,10 @@
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
<!-- Provided dependencies -->
<dependency>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 9224988e6a..2906ac8aba 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -19,6 +19,7 @@ package org.apache.spark.unsafe.memory;
import java.util.*;
+import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -47,10 +48,18 @@ public final class TaskMemoryManager {
private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
- /**
- * The number of entries in the page table.
- */
- private static final int PAGE_TABLE_SIZE = 1 << 13;
+ /** The number of bits used to address the page table. */
+ private static final int PAGE_NUMBER_BITS = 13;
+
+ /** The number of bits used to encode offsets in data pages. */
+ @VisibleForTesting
+ static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51
+
+ /** The number of entries in the page table. */
+ private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+ /** Maximum supported data page size */
+ private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
/** Bit mask for the lower 51 bits of a long. */
private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -101,11 +110,9 @@ public final class TaskMemoryManager {
* intended for allocating large blocks of memory that will be shared between operators.
*/
public MemoryBlock allocatePage(long size) {
- if (logger.isTraceEnabled()) {
- logger.trace("Allocating {} byte page", size);
- }
- if (size >= (1L << 51)) {
- throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+ if (size > MAXIMUM_PAGE_SIZE) {
+ throw new IllegalArgumentException(
+ "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
}
final int pageNumber;
@@ -120,8 +127,8 @@ public final class TaskMemoryManager {
final MemoryBlock page = executorMemoryManager.allocate(size);
page.pageNumber = pageNumber;
pageTable[pageNumber] = page;
- if (logger.isDebugEnabled()) {
- logger.debug("Allocate page number {} ({} bytes)", pageNumber, size);
+ if (logger.isTraceEnabled()) {
+ logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
}
return page;
}
@@ -130,9 +137,6 @@ public final class TaskMemoryManager {
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
*/
public void freePage(MemoryBlock page) {
- if (logger.isTraceEnabled()) {
- logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size());
- }
assert (page.pageNumber != -1) :
"Called freePage() on memory that wasn't allocated with allocatePage()";
executorMemoryManager.free(page);
@@ -140,8 +144,8 @@ public final class TaskMemoryManager {
allocatedPages.clear(page.pageNumber);
}
pageTable[page.pageNumber] = null;
- if (logger.isDebugEnabled()) {
- logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+ if (logger.isTraceEnabled()) {
+ logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
}
}
@@ -173,14 +177,36 @@ public final class TaskMemoryManager {
/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
+ *
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
+ * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+ * this should be the value that you would pass as the base offset into an
+ * UNSAFE call (e.g. page.baseOffset() + something).
+ * @return an encoded page address.
*/
public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
- if (inHeap) {
- assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
- return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
- } else {
- return offsetInPage;
+ if (!inHeap) {
+ // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+ // encode. Due to our page size limitation, though, we can convert this into an offset that's
+ // relative to the page's base offset; this relative offset will fit in 51 bits.
+ offsetInPage -= page.getBaseOffset();
}
+ return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+ }
+
+ @VisibleForTesting
+ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+ assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+ return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+ }
+
+ @VisibleForTesting
+ public static int decodePageNumber(long pagePlusOffsetAddress) {
+ return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
+ }
+
+ private static long decodeOffset(long pagePlusOffsetAddress) {
+ return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
}
/**
@@ -189,7 +215,7 @@ public final class TaskMemoryManager {
*/
public Object getPage(long pagePlusOffsetAddress) {
if (inHeap) {
- final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+ final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
final Object page = pageTable[pageNumber].getBaseObject();
assert (page != null);
@@ -204,10 +230,15 @@ public final class TaskMemoryManager {
* {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
*/
public long getOffsetInPage(long pagePlusOffsetAddress) {
+ final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
if (inHeap) {
- return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+ return offsetInPage;
} else {
- return pagePlusOffsetAddress;
+ // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+ // converted the absolute address into a relative address. Here, we invert that operation:
+ final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ return pageTable[pageNumber].getBaseOffset() + offsetInPage;
}
}
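
The refactored encoding above is plain bit arithmetic: a 13-bit page number in the upper bits and a 51-bit in-page offset in the lower bits of a single long. A self-contained sketch with the constants inlined (names chosen here for illustration, not Spark's):

object PageAddressSketch {
  val PageNumberBits = 13
  val OffsetBits = 64 - PageNumberBits          // 51
  val MaskLower51Bits = (1L << OffsetBits) - 1  // 0x7FFFFFFFFFFFFL

  def encode(pageNumber: Int, offsetInPage: Long): Long =
    (pageNumber.toLong << OffsetBits) | (offsetInPage & MaskLower51Bits)

  def decodePageNumber(address: Long): Int = (address >>> OffsetBits).toInt

  def decodeOffset(address: Long): Long = address & MaskLower51Bits

  def main(args: Array[String]): Unit = {
    val addr = encode(pageNumber = 5, offsetInPage = 64)
    assert(decodePageNumber(addr) == 5)
    assert(decodeOffset(addr) == 64)
    println(f"page=5, offset=64 -> 0x$addr%016x")
  }
}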
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
index 932882f1ca..06fb081183 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -38,4 +38,27 @@ public class TaskMemoryManagerSuite {
Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
}
+ @Test
+ public void encodePageNumberAndOffsetOffHeap() {
+ final TaskMemoryManager manager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+ final MemoryBlock dataPage = manager.allocatePage(256);
+ // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+ // encode. This test exercises that corner-case:
+ final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+ final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+ Assert.assertEquals(null, manager.getPage(encodedAddress));
+ Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+ }
+
+ @Test
+ public void encodePageNumberAndOffsetOnHeap() {
+ final TaskMemoryManager manager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = manager.allocatePage(256);
+ final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+ Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+ Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+ }
+
}