author    Josh Rosen <joshrosen@databricks.com>  2015-10-22 09:46:30 -0700
committer Josh Rosen <joshrosen@databricks.com>  2015-10-22 09:46:30 -0700
commit    f6d06adf05afa9c5386dc2396c94e7a98730289f (patch)
tree      9e3d8e4350e0a465124840eea91f6aa39c00b156 /core
parent    94e2064fa1b04c05c805d9175c7c78bf583db5c6 (diff)
[SPARK-10708] Consolidate sort shuffle implementations

There is a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Now that UnsafeShuffleManager supports large records, the two managers provide the same set of functionality, so we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and merge the two managers together.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 106
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java) | 2
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java) | 28
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java) | 16
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java) | 8
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java | 53
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java) | 4
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java (renamed from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java) | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala | 175
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala | 202
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala | 146
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala | 273
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java (renamed from core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java) | 5
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java (renamed from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java) | 16
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java (renamed from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java) | 10
-rw-r--r--  core/src/test/scala/org/apache/spark/SortShuffleSuite.scala | 65
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala | 64
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala (renamed from core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala) | 30
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala | 45
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala | 102
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala | 144
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala | 148
26 files changed, 435 insertions, 1290 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index f5d80bbcf3..ee82d67993 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -21,21 +21,30 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
+import javax.annotation.Nullable;
+import scala.None$;
+import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -62,7 +71,7 @@ import org.apache.spark.util.Utils;
* <p>
* There have been proposals to completely remove this code path; see SPARK-6026 for details.
*/
-final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
@@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final BlockManager blockManager;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
+ private final int shuffleId;
+ private final int mapId;
private final Serializer serializer;
+ private final IndexShuffleBlockResolver shuffleBlockResolver;
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
+ @Nullable private MapStatus mapStatus;
+ private long[] partitionLengths;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop() with success = true
+ * and then call stop() with success = false if they get an exception, we want to make sure
+ * we don't try deleting files, etc twice.
+ */
+ private boolean stopping = false;
public BypassMergeSortShuffleWriter(
- SparkConf conf,
BlockManager blockManager,
- Partitioner partitioner,
- ShuffleWriteMetrics writeMetrics,
- Serializer serializer) {
+ IndexShuffleBlockResolver shuffleBlockResolver,
+ BypassMergeSortShuffleHandle<K, V> handle,
+ int mapId,
+ TaskContext taskContext,
+ SparkConf conf) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
- this.numPartitions = partitioner.numPartitions();
this.blockManager = blockManager;
- this.partitioner = partitioner;
- this.writeMetrics = writeMetrics;
- this.serializer = serializer;
+ final ShuffleDependency<K, V, V> dep = handle.dependency();
+ this.mapId = mapId;
+ this.shuffleId = dep.shuffleId();
+ this.partitioner = dep.partitioner();
+ this.numPartitions = partitioner.numPartitions();
+ this.writeMetrics = new ShuffleWriteMetrics();
+ taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+ this.serializer = Serializer.getSerializer(dep.serializer());
+ this.shuffleBlockResolver = shuffleBlockResolver;
}
@Override
- public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
+ partitionLengths = new long[numPartitions];
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -124,13 +154,24 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}
+
+ partitionLengths =
+ writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
- @Override
- public long[] writePartitionedFile(
- BlockId blockId,
- TaskContext context,
- File outputFile) throws IOException {
+ @VisibleForTesting
+ long[] getPartitionLengths() {
+ return partitionLengths;
+ }
+
+ /**
+ * Concatenate all of the per-partition files into a single combined file.
+ *
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
+ */
+ private long[] writePartitionedFile(File outputFile) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
@@ -165,18 +206,33 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
}
@Override
- public void stop() throws IOException {
- if (partitionWriters != null) {
- try {
- for (DiskBlockObjectWriter writer : partitionWriters) {
- // This method explicitly does _not_ throw exceptions:
- File file = writer.revertPartialWritesAndClose();
- if (!file.delete()) {
- logger.error("Error while deleting file {}", file.getAbsolutePath());
+ public Option<MapStatus> stop(boolean success) {
+ if (stopping) {
+ return None$.empty();
+ } else {
+ stopping = true;
+ if (success) {
+ if (mapStatus == null) {
+ throw new IllegalStateException("Cannot call stop(true) without having called write()");
+ }
+ return Option.apply(mapStatus);
+ } else {
+ // The map task failed, so delete our output data.
+ if (partitionWriters != null) {
+ try {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ File file = writer.revertPartialWritesAndClose();
+ if (!file.delete()) {
+ logger.error("Error while deleting file {}", file.getAbsolutePath());
+ }
+ }
+ } finally {
+ partitionWriters = null;
}
}
- } finally {
- partitionWriters = null;
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+ return None$.empty();
}
}
}
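
The writePartitionedFile() method added above concatenates the per-partition temporary files into one combined data file and returns each partition's byte length for the map output tracker. A minimal standalone sketch of that concatenation step follows; the method name and helper are illustrative, not the patch's code, and the real writer can also use NIO transferTo when spark.file.transferTo is enabled:

    import java.io.{File, FileInputStream, FileOutputStream}
    import com.google.common.io.ByteStreams

    // Sketch only: append each per-partition temp file, in partition order,
    // to a single combined output file and record the byte length of each
    // partition (these lengths feed the shuffle index file / MapStatus).
    def concatPartitionFiles(partitionFiles: Seq[File], outputFile: File): Array[Long] = {
      val lengths = new Array[Long](partitionFiles.length)
      val out = new FileOutputStream(outputFile, true)
      try {
        for ((file, i) <- partitionFiles.zipWithIndex) {
          val in = new FileInputStream(file)
          try {
            lengths(i) = ByteStreams.copy(in, out)  // returns bytes copied
          } finally {
            in.close()
          }
        }
      } finally {
        out.close()
      }
      lengths
    }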
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
index 4ee6a82c04..c11711966f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
/**
* Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
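
That comment is the whole idea behind PackedRecordPointer: a single 8-byte word encodes both the partition number and the record's address, so the sorter only ever moves longs. A rough sketch of such an encoding, assuming the 24-bit/40-bit split stated above (the field layout here is illustrative and need not match the class's exact layout):

    // Illustrative packing of a 24-bit partition id (high bits) and a
    // 40-bit record pointer (low bits) into one long. Not the exact
    // layout used by PackedRecordPointer.
    object PackedPointerSketch {
      private val PointerMask = (1L << 40) - 1

      def pack(partitionId: Int, recordPointer: Long): Long = {
        require(partitionId >= 0 && partitionId < (1 << 24), "partition id must fit in 24 bits")
        (partitionId.toLong << 40) | (recordPointer & PointerMask)
      }

      def partitionId(packed: Long): Int = (packed >>> 40).toInt
      def recordPointer(packed: Long): Long = packed & PointerMask
    }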
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index e73ba39468..85fdaa8115 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
import javax.annotation.Nullable;
import java.io.File;
@@ -48,7 +48,7 @@ import org.apache.spark.util.Utils;
* <p>
* Incoming records are appended to data pages. When all records have been inserted (or when the
* current thread's shuffle memory limit is reached), the in-memory records are sorted according to
- * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
* written to a single output file (or multiple files, if we've spilled). The format of the output
* files is the same as the format of the final output file written by
* {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
@@ -59,9 +59,9 @@ import org.apache.spark.util.Utils;
* spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
* specialized merge procedure that avoids extra serialization/deserialization.
*/
-final class UnsafeShuffleExternalSorter {
+final class ShuffleExternalSorter {
- private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+ private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
@@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter {
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
+ private long numRecordsInsertedSinceLastSpill = 0;
+
+ /** Force this sorter to spill when there are this many elements in memory. For testing only */
+ private final long numElementsForSpillThreshold;
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSizeBytes;
@@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter {
private long peakMemoryUsedBytes;
// These variables are reset after spilling:
- @Nullable private UnsafeShuffleInMemorySorter inMemSorter;
+ @Nullable private ShuffleInMemorySorter inMemSorter;
@Nullable private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
- public UnsafeShuffleExternalSorter(
+ public ShuffleExternalSorter(
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
@@ -117,6 +121,8 @@ final class UnsafeShuffleExternalSorter {
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.numElementsForSpillThreshold =
+ conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.pageSizeBytes = (int) Math.min(
PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
this.maxRecordSizeBytes = pageSizeBytes - 4;
@@ -140,7 +146,8 @@ final class UnsafeShuffleExternalSorter {
throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
}
- this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
+ this.inMemSorter = new ShuffleInMemorySorter(initialSize);
+ numRecordsInsertedSinceLastSpill = 0;
}
/**
@@ -166,7 +173,7 @@ final class UnsafeShuffleExternalSorter {
}
// This call performs the actual sort.
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+ final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();
// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
@@ -406,6 +413,10 @@ final class UnsafeShuffleExternalSorter {
int lengthInBytes,
int partitionId) throws IOException {
+ if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+ spill();
+ }
+
growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int totalSpaceRequired = lengthInBytes + 4;
@@ -453,6 +464,7 @@ final class UnsafeShuffleExternalSorter {
recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, partitionId);
+ numRecordsInsertedSinceLastSpill += 1;
}
/**
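
The new numElementsForSpillThreshold field above is driven by spark.shuffle.spill.numElementsForceSpillThreshold (default Long.MAX_VALUE, so it never triggers in normal use). A hedged sketch of how a test might use it to force frequent spills; the values and app setup are illustrative:

    import org.apache.spark.SparkConf

    // Sketch only: force the shuffle sorter to spill after every 10 inserted
    // records so that spill-merging code paths get exercised in a small test.
    val testConf = new SparkConf()
      .setAppName("force-spill-test")                  // illustrative
      .setMaster("local[2]")                           // illustrative
      .set("spark.shuffle.manager", "tungsten-sort")
      .set("spark.shuffle.spill.numElementsForceSpillThreshold", "10")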
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 5bab501da9..a8dee6c610 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
import java.util.Comparator;
import org.apache.spark.util.collection.Sorter;
-final class UnsafeShuffleInMemorySorter {
+final class ShuffleInMemorySorter {
private final Sorter<PackedRecordPointer, long[]> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@@ -44,10 +44,10 @@ final class UnsafeShuffleInMemorySorter {
*/
private int pointerArrayInsertPosition = 0;
- public UnsafeShuffleInMemorySorter(int initialSize) {
+ public ShuffleInMemorySorter(int initialSize) {
assert (initialSize > 0);
this.pointerArray = new long[initialSize];
- this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+ this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
}
public void expandPointerArray() {
@@ -92,14 +92,14 @@ final class UnsafeShuffleInMemorySorter {
/**
* An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
*/
- public static final class UnsafeShuffleSorterIterator {
+ public static final class ShuffleSorterIterator {
private final long[] pointerArray;
private final int numRecords;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;
- public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+ public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
this.numRecords = numRecords;
this.pointerArray = pointerArray;
}
@@ -117,8 +117,8 @@ final class UnsafeShuffleInMemorySorter {
/**
* Return an iterator over record pointers in sorted order.
*/
- public UnsafeShuffleSorterIterator getSortedIterator() {
+ public ShuffleSorterIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
- return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+ return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
}
}
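
For intuition: the sorter above keeps one packed 8-byte word per record, sorts those words by partition id, and then iterates over them in order. A tiny standalone sketch of that idea (the real class uses Spark's Sorter with ShuffleSortDataFormat rather than boxing through sortBy):

    // Illustrative only: sort packed pointers by their 24-bit partition id,
    // stored in the high bits as in the sketch after PackedRecordPointer.
    def sortByPartition(packedPointers: Array[Long]): Array[Long] =
      packedPointers.sortBy(p => (p >>> 40).toInt)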
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
index a66d74ee44..8a1e5aec6f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -15,15 +15,15 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
import org.apache.spark.util.collection.SortDataFormat;
-final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
- public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+ public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
- private UnsafeShuffleSortDataFormat() { }
+ private ShuffleSortDataFormat() { }
@Override
public PackedRecordPointer getKey(long[] data, int pos) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
deleted file mode 100644
index 656ea0401a..0000000000
--- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort;
-
-import java.io.File;
-import java.io.IOException;
-
-import scala.Product2;
-import scala.collection.Iterator;
-
-import org.apache.spark.annotation.Private;
-import org.apache.spark.TaskContext;
-import org.apache.spark.storage.BlockId;
-
-/**
- * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
- */
-@Private
-public interface SortShuffleFileWriter<K, V> {
-
- void insertAll(Iterator<Product2<K, V>> records) throws IOException;
-
- /**
- * Write all the data added into this shuffle sorter into a file in the disk store. This is
- * called by the SortShuffleWriter and can go through an efficient path of just concatenating
- * binary files if we decided to avoid merge-sorting.
- *
- * @param blockId block ID to write to. The index file will be blockId.name + ".index".
- * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
- * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
- */
- long[] writePartitionedFile(
- BlockId blockId,
- TaskContext context,
- File outputFile) throws IOException;
-
- void stop() throws IOException;
-}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
index 7bac0dc0bb..df9f7b7abe 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
@@ -15,14 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
import java.io.File;
import org.apache.spark.storage.TempShuffleBlockId;
/**
- * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
+ * Metadata for a block of data written by {@link ShuffleExternalSorter}.
*/
final class SpillInfo {
final long[] partitionLengths;
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index fdb309e365..e8f050cb2d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
import javax.annotation.Nullable;
import java.io.*;
@@ -80,7 +80,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final boolean transferToEnabled;
@Nullable private MapStatus mapStatus;
- @Nullable private UnsafeShuffleExternalSorter sorter;
+ @Nullable private ShuffleExternalSorter sorter;
private long peakMemoryUsedBytes = 0;
/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
@@ -104,15 +104,15 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
- UnsafeShuffleHandle<K, V> handle,
+ SerializedShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf sparkConf) throws IOException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
- if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
+ if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
"UnsafeShuffleWriter can only be used for shuffles with at most " +
- UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
+ SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions");
}
this.blockManager = blockManager;
this.shuffleBlockResolver = shuffleBlockResolver;
@@ -195,7 +195,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private void open() throws IOException {
assert (sorter == null);
- sorter = new UnsafeShuffleExternalSorter(
+ sorter = new ShuffleExternalSorter(
memoryManager,
shuffleMemoryManager,
blockManager,
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index c329983451..704158bfc7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -330,7 +330,7 @@ object SparkEnv extends Logging {
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
"sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
- "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
+ "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
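
With this change, both the "sort" and "tungsten-sort" short names resolve to org.apache.spark.shuffle.sort.SortShuffleManager. A small usage sketch; the app name, master, and job are illustrative:

    import org.apache.spark.{SparkConf, SparkContext}

    // Sketch only: either short name now selects the consolidated
    // sort-based shuffle manager.
    val conf = new SparkConf()
      .setAppName("shuffle-manager-example")
      .setMaster("local[2]")
      .set("spark.shuffle.manager", "tungsten-sort")  // same class as "sort" after this patch

    val sc = new SparkContext(conf)
    sc.parallelize(1 to 10000)
      .map(i => (i % 8, 1))
      .reduceByKey(_ + _)  // triggers a shuffle through SortShuffleManager
      .collect()
    sc.stop()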
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 9df4e55166..1105167d39 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -19,9 +19,53 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
-import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency}
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
+/**
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output can be spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * Sort-based shuffle has two different write paths for producing its map output files:
+ *
+ * - Serialized sorting: used when all three of the following conditions hold:
+ * 1. The shuffle dependency specifies no aggregation or output ordering.
+ * 2. The shuffle serializer supports relocation of serialized values (this is currently
+ * supported by KryoSerializer and Spark SQL's custom serializers).
+ * 3. The shuffle produces fewer than 16777216 output partitions.
+ * - Deserialized sorting: used to handle all other cases.
+ *
+ * -----------------------
+ * Serialized sorting mode
+ * -----------------------
+ *
+ * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the
+ * shuffle writer and are buffered in a serialized form during sorting. This write path implements
+ * several optimizations:
+ *
+ * - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ * consumption and GC overheads. This optimization requires the record serializer to have certain
+ * properties to allow serialized records to be re-ordered without requiring deserialization.
+ * See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ * - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts
+ * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ * record in the sorting array, this fits more of the array into cache.
+ *
+ * - The spill merging procedure operates on blocks of serialized records that belong to the same
+ * partition and does not need to deserialize records during the merge.
+ *
+ * - When the spill compression codec supports concatenation of compressed data, the spill merge
+ * simply concatenates the serialized and compressed spill partitions to produce the final output
+ * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ * and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on these optimizations, see SPARK-7081.
+ */
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
if (!conf.getBoolean("spark.shuffle.spill", true)) {
@@ -30,8 +74,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
" Shuffle will continue to spill to disk when necessary.")
}
- private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf)
- private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
+ /**
+ * A mapping from shuffle ids to the number of mappers producing output for those shuffles.
+ */
+ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+
+ override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@@ -40,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
- new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
+ // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+ // need map-side aggregation, then write numPartitions files directly and just concatenate
+ // them at the end. This avoids doing serialization and deserialization twice to merge
+ // together the spilled files, which would happen with the normal code path. The downside is
+ // having multiple files open at a time and thus more memory allocated to buffers.
+ new BypassMergeSortShuffleHandle[K, V](
+ shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+ // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
+ new SerializedShuffleHandle[K, V](
+ shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else {
+ // Otherwise, buffer map outputs in a deserialized form:
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
}
/**
@@ -52,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
- // We currently use the same block store shuffle fetcher as the hash-based shuffle.
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
/** Get a writer for a given partition. Called on executors by map tasks. */
- override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
- : ShuffleWriter[K, V] = {
- val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
- shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
- new SortShuffleWriter(
- shuffleBlockResolver, baseShuffleHandle, mapId, context)
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Int,
+ context: TaskContext): ShuffleWriter[K, V] = {
+ numMapsForShuffle.putIfAbsent(
+ handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
+ val env = SparkEnv.get
+ handle match {
+ case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
+ new UnsafeShuffleWriter(
+ env.blockManager,
+ shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ context.taskMemoryManager(),
+ env.shuffleMemoryManager,
+ unsafeShuffleHandle,
+ mapId,
+ context,
+ env.conf)
+ case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
+ new BypassMergeSortShuffleWriter(
+ env.blockManager,
+ shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ bypassMergeSortHandle,
+ mapId,
+ context,
+ env.conf)
+ case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
+ new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+ }
}
/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
- if (shuffleMapNumber.containsKey(shuffleId)) {
- val numMaps = shuffleMapNumber.remove(shuffleId)
- (0 until numMaps).map{ mapId =>
+ Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps =>
+ (0 until numMaps).foreach { mapId =>
shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
}
- override val shuffleBlockResolver: IndexShuffleBlockResolver = {
- indexShuffleBlockResolver
- }
-
/** Shut down this ShuffleManager. */
override def stop(): Unit = {
shuffleBlockResolver.stop()
}
}
+
+private[spark] object SortShuffleManager extends Logging {
+
+ /**
+ * The maximum number of shuffle output partitions that SortShuffleManager supports when
+ * buffering map outputs in a serialized form. This is an extreme defensive programming measure,
+ * since it's extremely unlikely that a single shuffle produces over 16 million output partitions.
+ * */
+ val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE =
+ PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+ /**
+ * Helper method for determining whether a shuffle should use an optimized serialized shuffle
+ * path or whether it should fall back to the original path that operates on deserialized objects.
+ */
+ def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
+ val shufId = dependency.shuffleId
+ val numPartitions = dependency.partitioner.numPartitions
+ val serializer = Serializer.getSerializer(dependency.serializer)
+ if (!serializer.supportsRelocationOfSerializedObjects) {
+ log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
+ s"${serializer.getClass.getName}, does not support object relocation")
+ false
+ } else if (dependency.aggregator.isDefined) {
+ log.debug(
+ s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
+ false
+ } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
+ log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
+ s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
+ false
+ } else {
+ log.debug(s"Can use serialized shuffle for shuffle $shufId")
+ true
+ }
+ }
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * serialized shuffle.
+ */
+private[spark] class SerializedShuffleHandle[K, V](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * bypass merge sort shuffle path.
+ */
+private[spark] class BypassMergeSortShuffleHandle[K, V](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
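
To make the three serialized-sorting conditions above concrete, here is a hedged sketch of a job whose shuffle should qualify for the serialized (formerly "unsafe") path: Kryo supports relocation of serialized objects, partitionBy defines no aggregator or key ordering, and the partition count is far below 16777216 yet above the bypass-merge threshold, so the bypass path is not chosen first. Names and sizes are illustrative:

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("serialized-shuffle-example")  // illustrative
      .setMaster("local[4]")                     // illustrative
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)

    // partitionBy shuffles with no aggregation or ordering; using more
    // partitions than spark.shuffle.sort.bypassMergeThreshold (default 200)
    // keeps it off the bypass path, so registerShuffle should hand back a
    // SerializedShuffleHandle.
    val pairs = sc.parallelize(1 to 100000).map(i => (i, i.toString))
    pairs.partitionBy(new HashPartitioner(1000)).count()
    sc.stop()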
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 5865e7640c..bbd9c1ab53 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
@@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
- private var sorter: SortShuffleFileWriter[K, V] = null
+ private var sorter: ExternalSorter[K, V, _] = null
// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
@@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C](
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
- } else if (SortShuffleWriter.shouldBypassMergeSort(
- SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
- // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
- // need local aggregation and sorting, write numPartitions files directly and just concatenate
- // them at the end. This avoids doing serialization and deserialization twice to merge
- // together the spilled files, which would happen with the normal code path. The downside is
- // having multiple files open at a time and thus more memory allocated to buffers.
- new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
- writeMetrics, Serializer.getSerializer(dep.serializer))
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
@@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C](
}
private[spark] object SortShuffleWriter {
- def shouldBypassMergeSort(
- conf: SparkConf,
- numPartitions: Int,
- aggregator: Option[Aggregator[_, _, _]],
- keyOrdering: Option[Ordering[_]]): Boolean = {
- val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+ def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
+ // We cannot bypass sorting if we need to do map-side aggregation.
+ if (dep.mapSideCombine) {
+ require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
+ false
+ } else {
+ val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ dep.partitioner.numPartitions <= bypassMergeThreshold
+ }
}
}
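
The rewritten shouldBypassMergeSort above reduces the decision to two checks: no map-side combine, and at most spark.shuffle.sort.bypassMergeThreshold (default 200) reduce partitions. A short sketch restating that rule outside Spark, plus the configuration knob that controls it (the threshold value shown is illustrative):

    import org.apache.spark.SparkConf

    // Standalone restatement of the rule, for clarity only.
    def wouldBypassMergeSort(numPartitions: Int, mapSideCombine: Boolean, threshold: Int = 200): Boolean =
      !mapSideCombine && numPartitions <= threshold

    // Raising the threshold lets more small shuffles take the
    // write-then-concatenate path (the value 400 is illustrative).
    val conf = new SparkConf().set("spark.shuffle.sort.bypassMergeThreshold", "400")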
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
deleted file mode 100644
index 75f22f642b..0000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ /dev/null
@@ -1,202 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.util.Collections
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.spark._
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.sort.SortShuffleManager
-
-/**
- * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
- */
-private[spark] class UnsafeShuffleHandle[K, V](
- shuffleId: Int,
- numMaps: Int,
- dependency: ShuffleDependency[K, V, V])
- extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
-}
-
-private[spark] object UnsafeShuffleManager extends Logging {
-
- /**
- * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
- */
- val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
-
- /**
- * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
- * path or whether it should fall back to the original sort-based shuffle.
- */
- def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
- val shufId = dependency.shuffleId
- val serializer = Serializer.getSerializer(dependency.serializer)
- if (!serializer.supportsRelocationOfSerializedObjects) {
- log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
- s"${serializer.getClass.getName}, does not support object relocation")
- false
- } else if (dependency.aggregator.isDefined) {
- log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
- false
- } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
- log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
- s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
- false
- } else {
- log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
- true
- }
- }
-}
-
-/**
- * A shuffle implementation that uses directly-managed memory to implement several performance
- * optimizations for certain types of shuffles. In cases where the new performance optimizations
- * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
- * shuffles.
- *
- * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
- *
- * - The shuffle dependency specifies no aggregation or output ordering.
- * - The shuffle serializer supports relocation of serialized values (this is currently supported
- * by KryoSerializer and Spark SQL's custom serializers).
- * - The shuffle produces fewer than 16777216 output partitions.
- * - No individual record is larger than 128 MB when serialized.
- *
- * In addition, extra spill-merging optimizations are automatically applied when the shuffle
- * compression codec supports concatenation of serialized streams. This is currently supported by
- * Spark's LZF serializer.
- *
- * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
- * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
- * written to a single map output file. Reducers fetch contiguous regions of this file in order to
- * read their portion of the map output. In cases where the map output data is too large to fit in
- * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged
- * to produce the final output file.
- *
- * UnsafeShuffleManager optimizes this process in several ways:
- *
- * - Its sort operates on serialized binary data rather than Java objects, which reduces memory
- * consumption and GC overheads. This optimization requires the record serializer to have certain
- * properties to allow serialized records to be re-ordered without requiring deserialization.
- * See SPARK-4550, where this optimization was first proposed and implemented, for more details.
- *
- * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
- * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
- * record in the sorting array, this fits more of the array into cache.
- *
- * - The spill merging procedure operates on blocks of serialized records that belong to the same
- * partition and does not need to deserialize records during the merge.
- *
- * - When the spill compression codec supports concatenation of compressed data, the spill merge
- * simply concatenates the serialized and compressed spill partitions to produce the final output
- * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
- * and avoids the need to allocate decompression or copying buffers during the merge.
- *
- * For more details on UnsafeShuffleManager's design, see SPARK-7081.
- */
-private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
-
- if (!conf.getBoolean("spark.shuffle.spill", true)) {
- logWarning(
- "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
- "manager; its optimized shuffles will continue to spill to disk when necessary.")
- }
-
- private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
- private[this] val shufflesThatFellBackToSortShuffle =
- Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
- private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
-
- /**
- * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
- */
- override def registerShuffle[K, V, C](
- shuffleId: Int,
- numMaps: Int,
- dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
- if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
- new UnsafeShuffleHandle[K, V](
- shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
- } else {
- new BaseShuffleHandle(shuffleId, numMaps, dependency)
- }
- }
-
- /**
- * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
- * Called on executors by reduce tasks.
- */
- override def getReader[K, C](
- handle: ShuffleHandle,
- startPartition: Int,
- endPartition: Int,
- context: TaskContext): ShuffleReader[K, C] = {
- sortShuffleManager.getReader(handle, startPartition, endPartition, context)
- }
-
- /** Get a writer for a given partition. Called on executors by map tasks. */
- override def getWriter[K, V](
- handle: ShuffleHandle,
- mapId: Int,
- context: TaskContext): ShuffleWriter[K, V] = {
- handle match {
- case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] =>
- numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
- val env = SparkEnv.get
- new UnsafeShuffleWriter(
- env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
- context.taskMemoryManager(),
- env.shuffleMemoryManager,
- unsafeShuffleHandle,
- mapId,
- context,
- env.conf)
- case other =>
- shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
- sortShuffleManager.getWriter(handle, mapId, context)
- }
- }
-
- /** Remove a shuffle's metadata from the ShuffleManager. */
- override def unregisterShuffle(shuffleId: Int): Boolean = {
- if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
- sortShuffleManager.unregisterShuffle(shuffleId)
- } else {
- Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
- (0 until numMaps).foreach { mapId =>
- shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
- }
- }
- true
- }
- }
-
- override val shuffleBlockResolver: IndexShuffleBlockResolver = {
- sortShuffleManager.shuffleBlockResolver
- }
-
- /** Shut down this ShuffleManager. */
- override def stop(): Unit = {
- sortShuffleManager.stop()
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
deleted file mode 100644
index ae60f3b0cb..0000000000
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.OutputStream
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
- * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
- * of memory and needing to copy the full contents. The disadvantage is that the contents don't
- * occupy a contiguous segment of memory.
- */
-private[spark] class ChainedBuffer(chunkSize: Int) {
-
- private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
- java.lang.Long.highestOneBit(chunkSize))
- assert((1 << chunkSizeLog2) == chunkSize,
- s"ChainedBuffer chunk size $chunkSize must be a power of two")
- private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
- private var _size: Long = 0
-
- /**
- * Feed bytes from this buffer into a DiskBlockObjectWriter.
- *
- * @param pos Offset in the buffer to read from.
- * @param os OutputStream to read into.
- * @param len Number of bytes to read.
- */
- def read(pos: Long, os: OutputStream, len: Int): Unit = {
- if (pos + len > _size) {
- throw new IndexOutOfBoundsException(
- s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
- }
- var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
- var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
- var written: Int = 0
- while (written < len) {
- val toRead: Int = math.min(len - written, chunkSize - posInChunk)
- os.write(chunks(chunkIndex), posInChunk, toRead)
- written += toRead
- chunkIndex += 1
- posInChunk = 0
- }
- }
-
- /**
- * Read bytes from this buffer into a byte array.
- *
- * @param pos Offset in the buffer to read from.
- * @param bytes Byte array to read into.
- * @param offs Offset in the byte array to read to.
- * @param len Number of bytes to read.
- */
- def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
- if (pos + len > _size) {
- throw new IndexOutOfBoundsException(
- s"Read of $len bytes at position $pos would go past size of buffer")
- }
- var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
- var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
- var written: Int = 0
- while (written < len) {
- val toRead: Int = math.min(len - written, chunkSize - posInChunk)
- System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
- written += toRead
- chunkIndex += 1
- posInChunk = 0
- }
- }
-
- /**
- * Write bytes from a byte array into this buffer.
- *
- * @param pos Offset in the buffer to write to.
- * @param bytes Byte array to write from.
- * @param offs Offset in the byte array to write from.
- * @param len Number of bytes to write.
- */
- def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
- if (pos > _size) {
- throw new IndexOutOfBoundsException(
- s"Write at position $pos starts after end of buffer ${_size}")
- }
- // Grow if needed
- val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt
- while (endChunkIndex >= chunks.length) {
- chunks += new Array[Byte](chunkSize)
- }
-
- var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
- var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
- var written: Int = 0
- while (written < len) {
- val toWrite: Int = math.min(len - written, chunkSize - posInChunk)
- System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
- written += toWrite
- chunkIndex += 1
- posInChunk = 0
- }
-
- _size = math.max(_size, pos + len)
- }
-
- /**
- * Total size of buffer that can be written to without allocating additional memory.
- */
- def capacity: Long = chunks.size.toLong * chunkSize
-
- /**
- * Size of the logical buffer.
- */
- def size: Long = _size
-}
-
-/**
- * Output stream that writes to a ChainedBuffer.
- */
-private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
- private var pos: Long = 0
-
- override def write(b: Int): Unit = {
- throw new UnsupportedOperationException()
- }
-
- override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
- chainedBuffer.write(pos, bytes, offs, len)
- pos += len
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 749be34d8e..c48c453a90 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams
import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
@@ -69,8 +68,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
* At a high level, this class works internally as follows:
*
* - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
- * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
- * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ * we want to combine by key, or a PartitionedPairBuffer if we don't.
+ * Inside these buffers, we sort elements by partition ID and then possibly also by key.
* To avoid calling the partitioner multiple times with each key, we store the partition ID
* alongside each record.
*
@@ -93,8 +92,7 @@ private[spark] class ExternalSorter[K, V, C](
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
extends Logging
- with Spillable[WritablePartitionedPairCollection[K, C]]
- with SortShuffleFileWriter[K, V] {
+ with Spillable[WritablePartitionedPairCollection[K, C]] {
private val conf = SparkEnv.get.conf
@@ -104,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C](
if (shouldPartition) partitioner.get.getPartition(key) else 0
}
- // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
- // As a sanity check, make sure that we're not handling a shuffle which should use that path.
- if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
- throw new IllegalArgumentException("ExternalSorter should not be used to handle "
- + " a sort that the BypassMergeSortShuffleWriter should handle")
- }
-
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
private val ser = Serializer.getSerializer(serializer)
@@ -128,23 +119,11 @@ private[spark] class ExternalSorter[K, V, C](
// grow internal data structures by growing + copying every time the number of objects doubles.
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
- private val useSerializedPairBuffer =
- ordering.isEmpty &&
- conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
- ser.supportsRelocationOfSerializedObjects
- private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
- private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
- if (useSerializedPairBuffer) {
- new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
- } else {
- new PartitionedPairBuffer[K, C]
- }
- }
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
private var map = new PartitionedAppendOnlyMap[K, C]
- private var buffer = newBuffer()
+ private var buffer = new PartitionedPairBuffer[K, C]
// Total spilling statistics
private var _diskBytesSpilled = 0L
@@ -192,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C](
*/
private[spark] def numSpills: Int = spills.size
- override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
+ def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
@@ -236,7 +215,7 @@ private[spark] class ExternalSorter[K, V, C](
} else {
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
- buffer = newBuffer()
+ buffer = new PartitionedPairBuffer[K, C]
}
}
@@ -659,7 +638,7 @@ private[spark] class ExternalSorter[K, V, C](
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
- override def writePartitionedFile(
+ def writePartitionedFile(
blockId: BlockId,
context: TaskContext,
outputFile: File): Array[Long] = {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
deleted file mode 100644
index 87a786b02d..0000000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ /dev/null
@@ -1,273 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.InputStream
-import java.nio.IntBuffer
-import java.util.Comparator
-
-import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.DiskBlockObjectWriter
-import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
-
-/**
- * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
- * its records upon insert and stores them as raw bytes.
- *
- * We use two data-structures to store the contents. The serialized records are stored in a
- * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
- * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
- * record. Each entry in the metadata buffer takes up a fixed amount of space.
- *
- * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
- * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
- * happen without following any pointers, which should minimize cache misses.
- *
- * Currently, only sorting by partition is supported.
- *
- * Each record is laid out inside the metaBuffer as follows. keyStart, a long, is split across
- * two integers:
- *
- *   +-------------+------------+------------+-------------+
- *   |         keyStart         | keyValLen  | partitionId |
- *   +-------------+------------+------------+-------------+
- *
- * The buffer can support up to `536870911 (2 ^ 29 - 1)` records.
- *
- * @param metaInitialRecords The initial number of entries in the metadata buffer.
- * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
- * @param serializerInstance the serializer used for serializing inserted records.
- */
-private[spark] class PartitionedSerializedPairBuffer[K, V](
- metaInitialRecords: Int,
- kvBlockSize: Int,
- serializerInstance: SerializerInstance)
- extends WritablePartitionedPairCollection[K, V] with SizeTracker {
-
- if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
- throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
- " Java-serialized objects.")
- }
-
- require(metaInitialRecords <= MAXIMUM_RECORDS,
- s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records")
- private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
-
- private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
- private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
- private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
-
- def insert(partition: Int, key: K, value: V): Unit = {
- if (metaBuffer.position == metaBuffer.capacity) {
- growMetaBuffer()
- }
-
- val keyStart = kvBuffer.size
- kvSerializationStream.writeKey[Any](key)
- kvSerializationStream.writeValue[Any](value)
- kvSerializationStream.flush()
- val keyValLen = (kvBuffer.size - keyStart).toInt
-
- // keyStart, a long, gets split across two ints
- metaBuffer.put(keyStart.toInt)
- metaBuffer.put((keyStart >> 32).toInt)
- metaBuffer.put(keyValLen)
- metaBuffer.put(partition)
- }
-
- /** Double the size of the array because we've reached capacity */
- private def growMetaBuffer(): Unit = {
- if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) {
- throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records")
- }
- val newCapacity =
- if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) {
- // Overflow
- MAXIMUM_META_BUFFER_CAPACITY
- } else {
- metaBuffer.capacity * 2
- }
- val newMetaBuffer = IntBuffer.allocate(newCapacity)
- newMetaBuffer.put(metaBuffer.array)
- metaBuffer = newMetaBuffer
- }
-
- /** Iterate through the data in a given order. For this class this is not really destructive. */
- override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
- : Iterator[((Int, K), V)] = {
- sort(keyComparator)
- val is = orderedInputStream
- val deserStream = serializerInstance.deserializeStream(is)
- new Iterator[((Int, K), V)] {
- var metaBufferPos = 0
- def hasNext: Boolean = metaBufferPos < metaBuffer.position
- def next(): ((Int, K), V) = {
- val key = deserStream.readKey[Any]().asInstanceOf[K]
- val value = deserStream.readValue[Any]().asInstanceOf[V]
- val partition = metaBuffer.get(metaBufferPos + PARTITION)
- metaBufferPos += RECORD_SIZE
- ((partition, key), value)
- }
- }
- }
-
- override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity
-
- override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
- : WritablePartitionedIterator = {
- sort(keyComparator)
- new WritablePartitionedIterator {
- // current position in the meta buffer in ints
- var pos = 0
-
- def writeNext(writer: DiskBlockObjectWriter): Unit = {
- val keyStart = getKeyStartPos(metaBuffer, pos)
- val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
- pos += RECORD_SIZE
- kvBuffer.read(keyStart, writer, keyValLen)
- writer.recordWritten()
- }
- def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
- def hasNext(): Boolean = pos < metaBuffer.position
- }
- }
-
- // Visible for testing
- def orderedInputStream: OrderedInputStream = {
- new OrderedInputStream(metaBuffer, kvBuffer)
- }
-
- private def sort(keyComparator: Option[Comparator[K]]): Unit = {
- val comparator = if (keyComparator.isEmpty) {
- new Comparator[Int]() {
- def compare(partition1: Int, partition2: Int): Int = {
- partition1 - partition2
- }
- }
- } else {
- throw new UnsupportedOperationException()
- }
-
- val sorter = new Sorter(new SerializedSortDataFormat)
- sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
- }
-}
-
-private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
- extends InputStream {
-
- import PartitionedSerializedPairBuffer._
-
- private var metaBufferPos = 0
- private var kvBufferPos =
- if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0
-
- override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
-
- override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
- if (metaBufferPos >= metaBuffer.position) {
- return -1
- }
- val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
- (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
- val toRead = math.min(bytesRemainingInRecord, len)
- kvBuffer.read(kvBufferPos, bytes, offs, toRead)
- if (toRead == bytesRemainingInRecord) {
- metaBufferPos += RECORD_SIZE
- if (metaBufferPos < metaBuffer.position) {
- kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
- }
- } else {
- kvBufferPos += toRead
- }
- toRead
- }
-
- override def read(): Int = {
- throw new UnsupportedOperationException()
- }
-}
-
-private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
-
- private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
-
- /** Return the sort key for the element at the given index. */
- override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
- metaBuffer.get(pos * RECORD_SIZE + PARTITION)
- }
-
- /** Swap two elements. */
- override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
- val iOff = pos0 * RECORD_SIZE
- val jOff = pos1 * RECORD_SIZE
- System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
- System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
- System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
- }
-
- /** Copy a single element from src(srcPos) to dst(dstPos). */
- override def copyElement(
- src: IntBuffer,
- srcPos: Int,
- dst: IntBuffer,
- dstPos: Int): Unit = {
- val srcOff = srcPos * RECORD_SIZE
- val dstOff = dstPos * RECORD_SIZE
- System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
- }
-
- /**
- * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
- * Overlapping ranges are allowed.
- */
- override def copyRange(
- src: IntBuffer,
- srcPos: Int,
- dst: IntBuffer,
- dstPos: Int,
- length: Int): Unit = {
- val srcOff = srcPos * RECORD_SIZE
- val dstOff = dstPos * RECORD_SIZE
- System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
- }
-
- /**
- * Allocates a Buffer that can hold up to 'length' elements.
- * All elements of the buffer should be considered invalid until data is explicitly copied in.
- */
- override def allocate(length: Int): IntBuffer = {
- IntBuffer.allocate(length * RECORD_SIZE)
- }
-}
-
-private object PartitionedSerializedPairBuffer {
- val KEY_START = 0 // keyStart, a long, gets split across two ints
- val KEY_VAL_LEN = 2
- val PARTITION = 3
- val RECORD_SIZE = PARTITION + 1 // num ints of metadata
-
- val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1
- val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4
-
- def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
- val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
- val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
- (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
- }
-}
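
The removed buffer stores keyStart, a long byte offset, across two int slots of the metadata buffer; getKeyStartPos above reassembles it, masking the low word so a negative int does not sign-extend. A small stand-alone sketch of that round trip, with illustrative helper names:

def splitLong(keyStart: Long): (Int, Int) =
  (keyStart.toInt, (keyStart >> 32).toInt)          // low word first, as in insert()

def combineInts(lower32: Int, upper32: Int): Long =
  (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)  // mask prevents sign extension

val (lo, hi) = splitLong(5368709120L)               // an offset beyond 4 GB
assert(combineInts(lo, hi) == 5368709120L)
val (lo2, hi2) = splitLong(0xFFFFFFFFL)             // low word alone would read back as -1
assert(combineInts(lo2, hi2) == 0xFFFFFFFFL)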
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 934b7e0305..232ae4d926 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -15,8 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
+import org.apache.spark.shuffle.sort.PackedRecordPointer;
import org.junit.Test;
import static org.junit.Assert.*;
@@ -24,7 +25,7 @@ import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
public class PackedRecordPointerSuite {
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 40fefe2c9d..1ef3c5ff64 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
import java.util.Arrays;
import java.util.Random;
@@ -30,7 +30,7 @@ import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
-public class UnsafeShuffleInMemorySorterSuite {
+public class ShuffleInMemorySorterSuite {
private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
final byte[] strBytes = new byte[strLength];
@@ -40,8 +40,8 @@ public class UnsafeShuffleInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
- final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
+ final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
assert(!iter.hasNext());
}
@@ -62,7 +62,7 @@ public class UnsafeShuffleInMemorySorterSuite {
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
final MemoryBlock dataPage = memoryManager.allocatePage(2048);
final Object baseObject = dataPage.getBaseObject();
- final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Write the records into the data page and store pointers into the sorter
@@ -79,7 +79,7 @@ public class UnsafeShuffleInMemorySorterSuite {
}
// Sort the records
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
int prevPartitionId = -1;
Arrays.sort(dataToSort);
for (int i = 0; i < dataToSort.length; i++) {
@@ -103,7 +103,7 @@ public class UnsafeShuffleInMemorySorterSuite {
@Test
public void testSortingManyNumbers() throws Exception {
- UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+ ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
int[] numbersToSort = new int[128000];
Random random = new Random(16);
for (int i = 0; i < numbersToSort.length; i++) {
@@ -112,7 +112,7 @@ public class UnsafeShuffleInMemorySorterSuite {
}
Arrays.sort(numbersToSort);
int[] sorterResult = new int[numbersToSort.length];
- UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
int j = 0;
while (iter.hasNext()) {
iter.loadNext();
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index d218344cd4..29d9823b1f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.shuffle.sort;
import java.io.*;
import java.nio.ByteBuffer;
@@ -23,7 +23,6 @@ import java.util.*;
import scala.*;
import scala.collection.Iterator;
-import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;
import com.google.common.collect.Iterators;
@@ -56,6 +55,7 @@ import org.apache.spark.serializer.*;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
@@ -204,7 +204,7 @@ public class UnsafeShuffleWriterSuite {
shuffleBlockResolver,
taskMemoryManager,
shuffleMemoryManager,
- new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+ new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
0, // map id
taskContext,
conf
@@ -461,7 +461,7 @@ public class UnsafeShuffleWriterSuite {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite =
new ArrayList<Product2<Object, Object>>();
- final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+ final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
new Random(42).nextBytes(bytes);
dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
writer.write(dataToWrite.iterator());
@@ -516,7 +516,7 @@ public class UnsafeShuffleWriterSuite {
shuffleBlockResolver,
taskMemoryManager,
shuffleMemoryManager,
- new UnsafeShuffleHandle<>(0, 1, shuffleDep),
+ new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf);
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 63358172ea..b8ab227517 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -17,13 +17,78 @@
package org.apache.spark
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with sort-based shuffle.
+ private var tempDir: File = _
+
override def beforeAll() {
conf.set("spark.shuffle.manager", "sort")
}
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ conf.set("spark.local.dir", tempDir.getAbsolutePath)
+ }
+
+ override def afterEach(): Unit = {
+ try {
+ Utils.deleteRecursively(tempDir)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
+ sc = new SparkContext("local", "test", conf)
+ // Create a shuffled RDD and verify that it actually uses the new serialized map output path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new KryoSerializer(conf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+ ensureFilesAreCleanedUp(shuffledRdd)
+ }
+
+ test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
+ sc = new SparkContext("local", "test", conf)
+ // Create a shuffled RDD and verify that it actually uses the old deserialized map output path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new JavaSerializer(conf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+ ensureFilesAreCleanedUp(shuffledRdd)
+ }
+
+ private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
+ def getAllFiles: Set[File] =
+ FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+ val filesBeforeShuffle = getAllFiles
+ // Force the shuffle to be performed
+ shuffledRdd.count()
+ // Ensure that the shuffle actually created files that will need to be cleaned up
+ val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+ filesCreatedByShuffle.map(_.getName) should be
+ Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
+ // Check that the cleanup actually removes the files
+ sc.env.blockManager.master.removeShuffle(0, blocking = true)
+ for (file <- filesCreatedByShuffle) {
+ assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+ }
+ }
}
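
The two tests above pick their serializers to land on opposite sides of SortShuffleManager.canUseSerializedShuffle. Paraphrasing the conditions exercised here and in SortShuffleManagerSuite further down, as a sketch of the criteria rather than the manager's own code:

def canUseSerializedPathSketch(
    serializerSupportsRelocation: Boolean,  // Kryo: yes; plain Java serialization: no
    hasAggregator: Boolean,                 // map-side aggregation is not supported
    numPartitions: Int): Boolean = {
  val maxPartitions = 1 << 24               // assumed source of the "16 million" limit below
  serializerSupportsRelocation && !hasAggregator && numPartitions <= maxPartitions
}

A key ordering alone does not disqualify a shuffle, as the "supported shuffle dependencies" test below shows.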
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 5b01ddb298..3816b8c4a0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1062,10 +1062,10 @@ class DAGSchedulerSuite
*/
test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") {
val firstRDD = new MyRDD(sc, 3, Nil)
- val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+ val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
val firstShuffleId = firstShuffleDep.shuffleId
val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
- val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
submit(reduceRdd, Array(0))
@@ -1175,7 +1175,7 @@ class DAGSchedulerSuite
*/
test("register map outputs correctly after ExecutorLost and task Resubmitted") {
val firstRDD = new MyRDD(sc, 3, Nil)
- val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+ val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep))
submit(reduceRdd, Array(0))
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 341f56df2d..b92a302806 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark._
import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
-import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
@Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
private var taskMetrics: TaskMetrics = _
- private var shuffleWriteMetrics: ShuffleWriteMetrics = _
private var tempDir: File = _
private var outputFile: File = _
private val conf: SparkConf = new SparkConf(loadDefaults = false)
private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
- private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
- private val serializer: Serializer = new JavaSerializer(conf)
+ private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
override def beforeEach(): Unit = {
tempDir = Utils.createTempDir()
outputFile = File.createTempFile("shuffle", null, tempDir)
- shuffleWriteMetrics = new ShuffleWriteMetrics
taskMetrics = new TaskMetrics
- taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
MockitoAnnotations.initMocks(this)
+ shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
+ shuffleId = 0,
+ numMaps = 2,
+ dependency = dependency
+ )
+ when(dependency.partitioner).thenReturn(new HashPartitioner(7))
+ when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+ when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(blockManager.getDiskWriter(
any[BlockId],
@@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("write empty iterator") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
- new SparkConf(loadDefaults = false),
blockManager,
- new HashPartitioner(7),
- shuffleWriteMetrics,
- serializer
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ taskContext,
+ conf
)
- writer.insertAll(Iterator.empty)
- val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
- assert(partitionLengths.sum === 0)
+ writer.write(Iterator.empty)
+ writer.stop( /* success = */ true)
+ assert(writer.getPartitionLengths.sum === 0)
assert(outputFile.exists())
assert(outputFile.length() === 0)
assert(temporaryFilesCreated.isEmpty)
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
assert(taskMetrics.diskBytesSpilled === 0)
@@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
def records: Iterator[(Int, Int)] =
Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
val writer = new BypassMergeSortShuffleWriter[Int, Int](
- new SparkConf(loadDefaults = false),
blockManager,
- new HashPartitioner(7),
- shuffleWriteMetrics,
- serializer
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ taskContext,
+ conf
)
- writer.insertAll(records)
+ writer.write(records)
+ writer.stop( /* success = */ true)
assert(temporaryFilesCreated.nonEmpty)
- val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
- assert(partitionLengths.sum === outputFile.length())
+ assert(writer.getPartitionLengths.sum === outputFile.length())
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
assert(taskMetrics.diskBytesSpilled === 0)
@@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("cleanup of intermediate files after errors") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
- new SparkConf(loadDefaults = false),
blockManager,
- new HashPartitioner(7),
- shuffleWriteMetrics,
- serializer
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ taskContext,
+ conf
)
intercept[SparkException] {
- writer.insertAll((0 until 100000).iterator.map(i => {
+ writer.write((0 until 100000).iterator.map(i => {
if (i == 99990) {
throw new SparkException("Intentional failure")
}
@@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
}))
}
assert(temporaryFilesCreated.nonEmpty)
- writer.stop()
+ writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
}
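
The rewritten assertions above reflect the writer's new lifecycle: write the whole record iterator, stop with success, then read the per-partition lengths from the writer itself instead of calling writePartitionedFile. A usage sketch under those assumptions; writeAndMeasure is a hypothetical helper, not part of the patch's API.

import org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter

def writeAndMeasure[K, V](
    writer: BypassMergeSortShuffleWriter[K, V],
    records: Iterator[Product2[K, V]]): Long = {
  writer.write(records)                // consumes the iterator and writes the partitioned output
  writer.stop( /* success = */ true)   // commits the output (returns a MapStatus, ignored here)
  writer.getPartitionLengths.sum       // total bytes across partitions, used by the map output tracker
}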
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
index 6727934d8c..8744a072cb 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe
+package org.apache.spark.shuffle.sort
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
@@ -29,9 +29,9 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
* Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
* performed in other suites.
*/
-class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
+class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
- import UnsafeShuffleManager.canUseUnsafeShuffle
+ import SortShuffleManager.canUseSerializedShuffle
private class RuntimeExceptionAnswer extends Answer[Object] {
override def answer(invocation: InvocationOnMock): Object = {
@@ -55,10 +55,10 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
dep
}
- test("supported shuffle dependencies") {
+ test("supported shuffle dependencies for serialized shuffle") {
val kryo = Some(new KryoSerializer(new SparkConf()))
- assert(canUseUnsafeShuffle(shuffleDep(
+ assert(canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = None,
@@ -68,7 +68,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
when(rangePartitioner.numPartitions).thenReturn(2)
- assert(canUseUnsafeShuffle(shuffleDep(
+ assert(canUseSerializedShuffle(shuffleDep(
partitioner = rangePartitioner,
serializer = kryo,
keyOrdering = None,
@@ -77,7 +77,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
)))
// Shuffles with key orderings are supported as long as no aggregator is specified
- assert(canUseUnsafeShuffle(shuffleDep(
+ assert(canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = Some(mock(classOf[Ordering[Any]])),
@@ -87,12 +87,12 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
}
- test("unsupported shuffle dependencies") {
+ test("unsupported shuffle dependencies for serialized shuffle") {
val kryo = Some(new KryoSerializer(new SparkConf()))
val java = Some(new JavaSerializer(new SparkConf()))
// We only support serializers that support object relocation
- assert(!canUseUnsafeShuffle(shuffleDep(
+ assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = java,
keyOrdering = None,
@@ -100,9 +100,11 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
mapSideCombine = false
)))
- // We do not support shuffles with more than 16 million output partitions
- assert(!canUseUnsafeShuffle(shuffleDep(
- partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
+ // The serialized shuffle path does not support shuffles with more than 16 million output
+ // partitions, due to a limitation in its sorter implementation.
+ assert(!canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(
+ SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1),
serializer = kryo,
keyOrdering = None,
aggregator = None,
@@ -110,14 +112,14 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
)))
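
A quick back-of-envelope for the limit asserted above, assuming the serialized sorter's packed 8-byte record pointer reserves 24 bits for the partition ID (an assumption here, but it is what the 16 million figure implies):

val pointerBits = 64                             // a packed record pointer is one long
val partitionIdBits = 24                         // assumed width of the partition-id field
val addressBits = pointerBits - partitionIdBits  // 40 bits left to locate the record itself
val maxSerializedPartitions = 1 << partitionIdBits
assert(maxSerializedPartitions == 16777216)      // the "16 million" in the comment above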
// We do not support shuffles that perform aggregation
- assert(!canUseUnsafeShuffle(shuffleDep(
+ assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = None,
aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
mapSideCombine = false
)))
- assert(!canUseUnsafeShuffle(shuffleDep(
+ assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
keyOrdering = Some(mock(classOf[Ordering[Any]])),
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
deleted file mode 100644
index 34b4984f12..0000000000
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort
-
-import org.mockito.Mockito._
-
-import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite}
-
-class SortShuffleWriterSuite extends SparkFunSuite {
-
- import SortShuffleWriter._
-
- test("conditions for bypassing merge-sort") {
- val conf = new SparkConf(loadDefaults = false)
- val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
- val ord = implicitly[Ordering[Int]]
-
- // Numbers of partitions that are above and below the default bypassMergeThreshold
- val FEW_PARTITIONS = 50
- val MANY_PARTITIONS = 10000
-
- // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
- assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
- assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
-
- // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
- assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
- assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
deleted file mode 100644
index 259020a2dd..0000000000
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.io.File
-
-import scala.collection.JavaConverters._
-
-import org.apache.commons.io.FileUtils
-import org.apache.commons.io.filefilter.TrueFileFilter
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
-import org.apache.spark.rdd.ShuffledRDD
-import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-import org.apache.spark.util.Utils
-
-class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
-
- // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
-
- override def beforeAll() {
- conf.set("spark.shuffle.manager", "tungsten-sort")
- }
-
- test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
- val tmpDir = Utils.createTempDir()
- try {
- val myConf = conf.clone()
- .set("spark.local.dir", tmpDir.getAbsolutePath)
- sc = new SparkContext("local", "test", myConf)
- // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
- val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
- val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
- .setSerializer(new KryoSerializer(myConf))
- val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
- assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
- def getAllFiles: Set[File] =
- FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
- val filesBeforeShuffle = getAllFiles
- // Force the shuffle to be performed
- shuffledRdd.count()
- // Ensure that the shuffle actually created files that will need to be cleaned up
- val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
- filesCreatedByShuffle.map(_.getName) should be
- Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
- // Check that the cleanup actually removes the files
- sc.env.blockManager.master.removeShuffle(0, blocking = true)
- for (file <- filesCreatedByShuffle) {
- assert (!file.exists(), s"Shuffle file $file was not cleaned up")
- }
- } finally {
- Utils.deleteRecursively(tmpDir)
- }
- }
-
- test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
- val tmpDir = Utils.createTempDir()
- try {
- val myConf = conf.clone()
- .set("spark.local.dir", tmpDir.getAbsolutePath)
- sc = new SparkContext("local", "test", myConf)
- // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
- val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
- val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
- .setSerializer(new JavaSerializer(myConf))
- val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
- assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
- def getAllFiles: Set[File] =
- FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
- val filesBeforeShuffle = getAllFiles
- // Force the shuffle to be performed
- shuffledRdd.count()
- // Ensure that the shuffle actually created files that will need to be cleaned up
- val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
- filesCreatedByShuffle.map(_.getName) should be
- Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
- // Check that the cleanup actually removes the files
- sc.env.blockManager.master.removeShuffle(0, blocking = true)
- for (file <- filesCreatedByShuffle) {
- assert (!file.exists(), s"Shuffle file $file was not cleaned up")
- }
- } finally {
- Utils.deleteRecursively(tmpDir)
- }
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
deleted file mode 100644
index 05306f4088..0000000000
--- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.nio.ByteBuffer
-
-import org.scalatest.Matchers._
-
-import org.apache.spark.SparkFunSuite
-
-class ChainedBufferSuite extends SparkFunSuite {
- test("write and read at start") {
- // write from start of source array
- val buffer = new ChainedBuffer(8)
- buffer.capacity should be (0)
- verifyWriteAndRead(buffer, 0, 0, 0, 4)
- buffer.capacity should be (8)
-
- // write from middle of source array
- verifyWriteAndRead(buffer, 0, 5, 0, 4)
- buffer.capacity should be (8)
-
- // read to middle of target array
- verifyWriteAndRead(buffer, 0, 0, 5, 4)
- buffer.capacity should be (8)
-
- // write up to border
- verifyWriteAndRead(buffer, 0, 0, 0, 8)
- buffer.capacity should be (8)
-
- // expand into second buffer
- verifyWriteAndRead(buffer, 0, 0, 0, 12)
- buffer.capacity should be (16)
-
- // expand into multiple buffers
- verifyWriteAndRead(buffer, 0, 0, 0, 28)
- buffer.capacity should be (32)
- }
-
- test("write and read at middle") {
- val buffer = new ChainedBuffer(8)
-
- // fill to a middle point
- verifyWriteAndRead(buffer, 0, 0, 0, 3)
-
- // write from start of source array
- verifyWriteAndRead(buffer, 3, 0, 0, 4)
- buffer.capacity should be (8)
-
- // write from middle of source array
- verifyWriteAndRead(buffer, 3, 5, 0, 4)
- buffer.capacity should be (8)
-
- // read to middle of target array
- verifyWriteAndRead(buffer, 3, 0, 5, 4)
- buffer.capacity should be (8)
-
- // write up to border
- verifyWriteAndRead(buffer, 3, 0, 0, 5)
- buffer.capacity should be (8)
-
- // expand into second buffer
- verifyWriteAndRead(buffer, 3, 0, 0, 12)
- buffer.capacity should be (16)
-
- // expand into multiple buffers
- verifyWriteAndRead(buffer, 3, 0, 0, 28)
- buffer.capacity should be (32)
- }
-
- test("write and read at later buffer") {
- val buffer = new ChainedBuffer(8)
-
- // fill to a middle point
- verifyWriteAndRead(buffer, 0, 0, 0, 11)
-
- // write from start of source array
- verifyWriteAndRead(buffer, 11, 0, 0, 4)
- buffer.capacity should be (16)
-
- // write from middle of source array
- verifyWriteAndRead(buffer, 11, 5, 0, 4)
- buffer.capacity should be (16)
-
- // read to middle of target array
- verifyWriteAndRead(buffer, 11, 0, 5, 4)
- buffer.capacity should be (16)
-
- // write up to border
- verifyWriteAndRead(buffer, 11, 0, 0, 5)
- buffer.capacity should be (16)
-
- // expand into second buffer
- verifyWriteAndRead(buffer, 11, 0, 0, 12)
- buffer.capacity should be (24)
-
- // expand into multiple buffers
- verifyWriteAndRead(buffer, 11, 0, 0, 28)
- buffer.capacity should be (40)
- }
-
-
- // Used to make sure we're writing different bytes each time
- var rangeStart = 0
-
- /**
- * @param buffer The buffer to write to and read from.
- * @param offsetInBuffer The offset to write to in the buffer.
- * @param offsetInSource The offset in the array that the bytes are written from.
- * @param offsetInTarget The offset in the array to read the bytes into.
- * @param length The number of bytes to read and write
- */
- def verifyWriteAndRead(
- buffer: ChainedBuffer,
- offsetInBuffer: Int,
- offsetInSource: Int,
- offsetInTarget: Int,
- length: Int): Unit = {
- val source = new Array[Byte](offsetInSource + length)
- (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
- buffer.write(offsetInBuffer, source, offsetInSource, length)
- val target = new Array[Byte](offsetInTarget + length)
- buffer.read(offsetInBuffer, target, offsetInTarget, length)
- ByteBuffer.wrap(source, offsetInSource, length) should be
- (ByteBuffer.wrap(target, offsetInTarget, length))
-
- rangeStart += 100
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
deleted file mode 100644
index 3b67f62064..0000000000
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-
-import com.google.common.io.ByteStreams
-
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.Mockito.RETURNS_SMART_NULLS
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.scalatest.Matchers._
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.DiskBlockObjectWriter
-
-class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
- test("OrderedInputStream single record") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct = SomeStruct("something", 5)
- buffer.insert(4, 10, struct)
-
- val bytes = ByteStreams.toByteArray(buffer.orderedInputStream)
-
- val baos = new ByteArrayOutputStream()
- val stream = serializerInstance.serializeStream(baos)
- stream.writeObject(10)
- stream.writeObject(struct)
- stream.close()
-
- baos.toByteArray should be (bytes)
- }
-
- test("insert single record") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct = SomeStruct("something", 5)
- buffer.insert(4, 10, struct)
- val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
- elements.size should be (1)
- elements.head should be (((4, 10), struct))
- }
-
- test("insert multiple records") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct1 = SomeStruct("something1", 8)
- buffer.insert(6, 1, struct1)
- val struct2 = SomeStruct("something2", 9)
- buffer.insert(4, 2, struct2)
- val struct3 = SomeStruct("something3", 10)
- buffer.insert(5, 3, struct3)
-
- val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
- elements.size should be (3)
- elements(0) should be (((4, 2), struct2))
- elements(1) should be (((5, 3), struct3))
- elements(2) should be (((6, 1), struct1))
- }
-
- test("write single record") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct = SomeStruct("something", 5)
- buffer.insert(4, 10, struct)
- val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val (writer, baos) = createMockWriter()
- assert(it.hasNext)
- it.nextPartition should be (4)
- it.writeNext(writer)
- assert(!it.hasNext)
-
- val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
- stream.readObject[AnyRef]() should be (10)
- stream.readObject[AnyRef]() should be (struct)
- }
-
- test("write multiple records") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct1 = SomeStruct("something1", 8)
- buffer.insert(6, 1, struct1)
- val struct2 = SomeStruct("something2", 9)
- buffer.insert(4, 2, struct2)
- val struct3 = SomeStruct("something3", 10)
- buffer.insert(5, 3, struct3)
-
- val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val (writer, baos) = createMockWriter()
- assert(it.hasNext)
- it.nextPartition should be (4)
- it.writeNext(writer)
- assert(it.hasNext)
- it.nextPartition should be (5)
- it.writeNext(writer)
- assert(it.hasNext)
- it.nextPartition should be (6)
- it.writeNext(writer)
- assert(!it.hasNext)
-
- val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
- val iter = stream.asIterator
- iter.next() should be (2)
- iter.next() should be (struct2)
- iter.next() should be (3)
- iter.next() should be (struct3)
- iter.next() should be (1)
- iter.next() should be (struct1)
- assert(!iter.hasNext)
- }
-
- def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
- val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
- val baos = new ByteArrayOutputStream()
- when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
- override def answer(invocationOnMock: InvocationOnMock): Unit = {
- val args = invocationOnMock.getArguments
- val bytes = args(0).asInstanceOf[Array[Byte]]
- val offset = args(1).asInstanceOf[Int]
- val length = args(2).asInstanceOf[Int]
- baos.write(bytes, offset, length)
- }
- })
- (writer, baos)
- }
-}
-
-case class SomeStruct(str: String, num: Int)