author     Josh Rosen <joshrosen@databricks.com>    2015-05-30 15:27:51 -0700
committer  Reynold Xin <rxin@databricks.com>        2015-05-30 15:27:51 -0700
commit     a6430028ecd7a6130f1eb15af9ec00e242c46725 (patch)
tree       e039aabe2fcdd77d97c4d837c3ce194d638ae66c /core
parent     7716a5a1ec8ff8dc24e0146f8ead2f51da6512ad (diff)
[SPARK-7855] Move bypassMergeSort-handling from ExternalSorter to own component
Spark's `ExternalSorter` writes shuffle output files during sort-based shuffle. Sort-shuffle contains a configuration, `spark.shuffle.sort.bypassMergeThreshold`, which causes ExternalSorter to skip sorting and merging and simply write separate files per partition, which are then concatenated together to form the final map output file. The code paths used during this bypass are almost completely separate from ExternalSorter's other code paths, so refactoring them into a separate file can significantly simplify the code.

In addition to re-arranging code, this patch deletes a bunch of dead code. The main entry point into ExternalSorter is `insertAll()` and in SPARK-4479 / #3422 this method was modified to completely bypass in-memory buffering of records when `bypassMergeSort` takes effect. As a result, some of the spilling and merging code paths will no longer be called when `bypassMergeSort` is used, so we should be able to safely remove that code.

There's an open JIRA ([SPARK-6026](https://issues.apache.org/jira/browse/SPARK-6026)) for removing the `bypassMergeThreshold` parameter and code paths; I have not done that here, but the changes in this patch will make removing that parameter significantly easier if we ever decide to do that.

This patch also makes several improvements to shuffle-related tests and adds more defensive checks to certain shuffle classes:

- DiskBlockObjectWriter now throws an exception if `fileSegment()` is called before `commitAndClose()` has been called.
- DiskBlockObjectWriter's close methods are now idempotent, so calling any of the close methods twice in a row will no longer result in incorrect shuffle write metrics changes. Calling `revertPartialWritesAndClose()` on a closed DiskBlockObjectWriter now has no effect (before, it might mess up the metrics).
- The end-to-end shuffle record count metrics tests have been moved from InputOutputMetricsSuite to ShuffleSuite. This means that these tests will now be run against all shuffle implementations rather than just the default shuffle configuration.
- The end-to-end metrics tests now include a test of a job which performs aggregation in the shuffle.
- Our tests now check that `shuffleBytesWritten == totalShuffleBytesRead`.
- FileSegment now throws IllegalArgumentException if it is constructed with a negative length or offset.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #6397 from JoshRosen/external-sorter-bypass-cleanup and squashes the following commits:

bf3f3f6 [Josh Rosen] Merge remote-tracking branch 'origin/master' into external-sorter-bypass-cleanup
8b216c4 [Josh Rosen] Guard against negative offsets and lengths in FileSegment
03f35a4 [Josh Rosen] Minor fix to cleanup logic.
b5cc35b [Josh Rosen] Move shuffle metrics tests to ShuffleSuite.
8b8fb9e [Josh Rosen] Add more tests + defensive programming to DiskBlockObjectWriter.
16564eb [Josh Rosen] Guard against calling fileSegment() before commitAndClose() has been called.
96811b4 [Josh Rosen] Remove confusing taskMetrics.shuffleWriteMetrics() optional call
8522b6a [Josh Rosen] Do not perform a map-side sort unless we're also doing map-side aggregation
08e40f3 [Josh Rosen] Remove excessively clever (and wrong) implementation of newBuffer()
d7f9938 [Josh Rosen] Add missing overrides; fix compilation
71d76ff [Josh Rosen] Update Javadoc
bf0d98f [Josh Rosen] Add comment to clarify confusing factory code
5197f73 [Josh Rosen] Add missing private[this]
30ef2c8 [Josh Rosen] Convert BypassMergeSortShuffleWriter to Java
bc1a820 [Josh Rosen] Fix bug when aggregator is used but map-side combine is disabled
0d3dcc0 [Josh Rosen] Remove unnecessary overloaded methods
25b964f [Josh Rosen] Rename SortShuffleSorter to SortShuffleFileWriter
0d9848c [Josh Rosen] Make it more clear that curWriteMetrics is now only used for spill metrics
7af7aea [Josh Rosen] Combine spill() and spillToMergeableFile()
6320112 [Josh Rosen] Add missing negation in deletion success check.
d267e0d [Josh Rosen] Fix style issue
7f15f7b [Josh Rosen] Back out extra cleanup-handling code, since this is already covered in stop()
25aa3bd [Josh Rosen] Make sure to delete outputFile after errors.
931ca68 [Josh Rosen] Refactor tests.
6a35716 [Josh Rosen] Refactor logic for deciding when to bypass
4b03539 [Josh Rosen] Move conf prior to first use
1265b25 [Josh Rosen] Fix some style errors and comments.
02355ef [Josh Rosen] More simplification
d4cb536 [Josh Rosen] Delete more unused code
bb96678 [Josh Rosen] Add missing interface file
b6cc1eb [Josh Rosen] Realize that bypass never buffers; proceed to delete tons of code
6185ee2 [Josh Rosen] WIP towards moving bypass code into own file.
8d0678c [Josh Rosen] Move diskBytesSpilled getter next to variable
19bccd6 [Josh Rosen] Remove duplicated buffer creation code.
18959bb [Josh Rosen] Move comparator methods closer together.
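For context, the bypass strategy that this patch isolates works by writing one temporary file per reduce partition and then concatenating those files into the single map output file, recording each partition's length so that reducers can be served the right byte range. Below is a minimal, self-contained Scala sketch of that concatenation step only; the object and method names are made up for illustration, and the actual writer (in the diff below) uses Utils.copyStream with optional transferTo support instead.

```scala
import java.io.{File, FileInputStream, FileOutputStream}

object ConcatenationSketch {
  /** Concatenate per-partition files into `output`, returning each partition's length in bytes. */
  def concatenatePartitionFiles(partitionFiles: Seq[File], output: File): Array[Long] = {
    val lengths = new Array[Long](partitionFiles.length)
    val out = new FileOutputStream(output, true) // append, like the real writer
    try {
      partitionFiles.zipWithIndex.foreach { case (file, i) =>
        val in = new FileInputStream(file)
        try {
          // NIO transferTo stands in for Spark's Utils.copyStream(in, out, false, transferToEnabled)
          lengths(i) = in.getChannel.transferTo(0, file.length(), out.getChannel)
        } finally {
          in.close()
        }
      }
    } finally {
      out.close()
    }
    lengths
  }
}
```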
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java        184
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java                53
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala                  34
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala                       19
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/FileSegment.scala                              2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala                 260
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala                    24
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala         4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala            4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala 36
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala                                     65
-rw-r--r--  core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala                  28
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala  171
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala              46
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala                   97
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala             130
17 files changed, 738 insertions, 423 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000000..d3d6280284
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ * <p>
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ * <ul>
+ * <li>no Ordering is specified,</li>
+ * <li>no Aggregator is specified, and</li>
+ * <li>the number of partitions is less than
+ * <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ * <p>
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+ private final int fileBufferSize;
+ private final boolean transferToEnabled;
+ private final int numPartitions;
+ private final BlockManager blockManager;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final Serializer serializer;
+
+ /** Array of file writers, one for each partition */
+ private BlockObjectWriter[] partitionWriters;
+
+ public BypassMergeSortShuffleWriter(
+ SparkConf conf,
+ BlockManager blockManager,
+ Partitioner partitioner,
+ ShuffleWriteMetrics writeMetrics,
+ Serializer serializer) {
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+ this.numPartitions = partitioner.numPartitions();
+ this.blockManager = blockManager;
+ this.partitioner = partitioner;
+ this.writeMetrics = writeMetrics;
+ this.serializer = serializer;
+ }
+
+ @Override
+ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+ assert (partitionWriters == null);
+ if (!records.hasNext()) {
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new BlockObjectWriter[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+ }
+
+ for (BlockObjectWriter writer : partitionWriters) {
+ writer.commitAndClose();
+ }
+ }
+
+ @Override
+ public long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException {
+ // Track location of the partition starts in the output file
+ final long[] lengths = new long[numPartitions];
+ if (partitionWriters == null) {
+ // We were passed an empty iterator
+ return lengths;
+ }
+
+ final FileOutputStream out = new FileOutputStream(outputFile, true);
+ final long writeStartTime = System.nanoTime();
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < numPartitions; i++) {
+ final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+ boolean copyThrewException = true;
+ try {
+ lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+ logger.error("Unable to delete file for partition {}", i);
+ }
+ }
+ threwException = false;
+ } finally {
+ Closeables.close(out, threwException);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ }
+ partitionWriters = null;
+ return lengths;
+ }
+
+ @Override
+ public void stop() throws IOException {
+ if (partitionWriters != null) {
+ try {
+ final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+ for (BlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ writer.revertPartialWritesAndClose();
+ if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+ logger.error("Error while deleting file for block {}", writer.blockId());
+ }
+ }
+ } finally {
+ partitionWriters = null;
+ }
+ }
+ }
+}
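A condensed sketch of the new writer's lifecycle, mirroring the BypassMergeSortShuffleWriterSuite added later in this diff; it assumes the same mocked BlockManager/DiskBlockManager plumbing and pre-created outputFile as that suite, so treat it as a sketch rather than standalone code:

```scala
// insertAll() opens one DiskBlockObjectWriter per partition and streams records to them;
// writePartitionedFile() concatenates the per-partition temp files into outputFile and
// returns each partition's length; stop() deletes any remaining temporary files.
val writer = new BypassMergeSortShuffleWriter[Int, Int](
  conf, blockManager, new HashPartitioner(7), new ShuffleWriteMetrics, serializer)
writer.insertAll(Iterator((1, 1), (2, 2), (5, 5)))
val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
assert(partitionLengths.sum == outputFile.length())
writer.stop()
```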
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
new file mode 100644
index 0000000000..656ea0401a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
+
+/**
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
+ */
+@Private
+public interface SortShuffleFileWriter<K, V> {
+
+ void insertAll(Iterator<Product2<K, V>> records) throws IOException;
+
+ /**
+ * Write all the data added into this shuffle sorter into a file in the disk store. This is
+ * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+ * binary files if we decided to avoid merge-sorting.
+ *
+ * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+ * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+ */
+ long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException;
+
+ void stop() throws IOException;
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index c9dd6bfc4c..5865e7640c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,9 +17,10 @@
package org.apache.spark.shuffle.sort
-import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
+import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
@@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
- private var sorter: ExternalSorter[K, V, _] = null
+ private var sorter: SortShuffleFileWriter[K, V] = null
// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
@@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C](
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
- if (dep.mapSideCombine) {
+ sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
- sorter = new ExternalSorter[K, V, C](
+ new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
- sorter.insertAll(records)
+ } else if (SortShuffleWriter.shouldBypassMergeSort(
+ SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
+ // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+ // need local aggregation and sorting, write numPartitions files directly and just concatenate
+ // them at the end. This avoids doing serialization and deserialization twice to merge
+ // together the spilled files, which would happen with the normal code path. The downside is
+ // having multiple files open at a time and thus more memory allocated to buffers.
+ new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
+ writeMetrics, Serializer.getSerializer(dep.serializer))
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
- sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer)
- sorter.insertAll(records)
+ new ExternalSorter[K, V, V](
+ aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
+ sorter.insertAll(records)
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
@@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C](
}
}
+private[spark] object SortShuffleWriter {
+ def shouldBypassMergeSort(
+ conf: SparkConf,
+ numPartitions: Int,
+ aggregator: Option[Aggregator[_, _, _]],
+ keyOrdering: Option[Ordering[_]]): Boolean = {
+ val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+ }
+}
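The shouldBypassMergeSort helper added above can be exercised directly; a minimal sketch follows (the example object is hypothetical, and it sits under org.apache.spark.shuffle.sort because the enclosing object is private[spark]):

```scala
package org.apache.spark.shuffle.sort

import org.apache.spark.SparkConf

object BypassDecisionExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false) // default threshold is 200
    // Few partitions, no aggregator, no key ordering -> bypass merge-sort.
    assert(SortShuffleWriter.shouldBypassMergeSort(conf, 100, aggregator = None, keyOrdering = None))
    // Too many partitions -> use ExternalSorter instead.
    assert(!SortShuffleWriter.shouldBypassMergeSort(conf, 1000, aggregator = None, keyOrdering = None))
    // A key ordering requirement also disables the bypass.
    assert(!SortShuffleWriter.shouldBypassMergeSort(
      conf, 100, aggregator = None, keyOrdering = Some(Ordering.Int)))
  }
}
```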
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index a33f22ef52..7eeabd1e04 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter(
private var objOut: SerializationStream = null
private var initialized = false
private var hasBeenClosed = false
+ private var commitAndCloseHasBeenCalled = false
/**
* Cursors used to represent positions in the file.
@@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter(
objOut.flush()
bs.flush()
close()
+ finalPosition = file.length()
+ // In certain compression codecs, more bytes are written after close() is called
+ writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+ } else {
+ finalPosition = file.length()
}
- finalPosition = file.length()
- // In certain compression codecs, more bytes are written after close() is called
- writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+ commitAndCloseHasBeenCalled = true
}
// Discard current writes. We do this by flushing the outstanding writes and then
// truncating the file to its initial position.
override def revertPartialWritesAndClose() {
try {
- writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
- writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
-
if (initialized) {
+ writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
+ writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
objOut.flush()
bs.flush()
close()
@@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter(
}
override def fileSegment(): FileSegment = {
+ if (!commitAndCloseHasBeenCalled) {
+ throw new IllegalStateException(
+ "fileSegment() is only valid after commitAndClose() has been called")
+ }
new FileSegment(file, initialPosition, finalPosition - initialPosition)
}
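A short sketch of the tightened DiskBlockObjectWriter contract; the writer is obtained through BlockManager.getDiskWriter as elsewhere in this patch, and blockId, file, serInstance, fileBufferSize, writeMetrics, key and value are assumed to be in scope:

```scala
val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics)
writer.write(key, value)
writer.commitAndClose()               // must be called before fileSegment()
val segment = writer.fileSegment()    // ok now; throws IllegalStateException if called earlier
writer.revertPartialWritesAndClose()  // per the commit message, a no-op on a closed writer
```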
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
index 95e2d688d9..021a9facfb 100644
--- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -24,6 +24,8 @@ import java.io.File
* based off an offset and a length.
*/
private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
+ require(offset >= 0, s"File segment offset cannot be negative (got $offset)")
+ require(length >= 0, s"File segment length cannot be negative (got $length)")
override def toString: String = {
"(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
}
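With these guards, invalid segments now fail fast at construction time; a quick illustration (the file name is hypothetical, and FileSegment is private[spark], so this only compiles from within Spark's own packages):

```scala
import java.io.File

new FileSegment(new File("shuffle_0_0_0.data"), 0L, 128L)   // valid
new FileSegment(new File("shuffle_0_0_0.data"), -1L, 128L)  // throws IllegalArgumentException
new FileSegment(new File("shuffle_0_0_0.data"), 0L, -5L)    // throws IllegalArgumentException
```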
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3b9d14f937..ef2dbb7ff0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -23,12 +23,14 @@ import java.util.Comparator
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
+import com.google.common.annotations.VisibleForTesting
import com.google.common.io.ByteStreams
import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.storage.{BlockObjectWriter, BlockId}
+import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
+import org.apache.spark.storage.{BlockId, BlockObjectWriter}
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -84,35 +86,40 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
* each other for equality to merge values.
*
* - Users are expected to call stop() at the end to delete all the intermediate files.
- *
- * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
- * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
- * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
- * then concatenate these files to produce a single sorted file, without having to serialize and
- * de-serialize each item twice (as is needed during the merge). This speeds up the map side of
- * groupBy, sort, etc operations since they do no partial aggregation.
*/
private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
- extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] {
+ extends Logging
+ with Spillable[WritablePartitionedPairCollection[K, C]]
+ with SortShuffleFileWriter[K, V] {
+
+ private val conf = SparkEnv.get.conf
private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
private val shouldPartition = numPartitions > 1
+ private def getPartition(key: K): Int = {
+ if (shouldPartition) partitioner.get.getPartition(key) else 0
+ }
+
+ // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
+ // As a sanity check, make sure that we're not handling a shuffle which should use that path.
+ if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
+ throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+ + " a sort that the BypassMergeSortShuffleWriter should handle")
+ }
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
private val ser = Serializer.getSerializer(serializer)
private val serInstance = ser.newInstance()
- private val conf = SparkEnv.get.conf
private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
- private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
// Size of object batches when reading/writing from serializers.
//
@@ -123,43 +130,28 @@ private[spark] class ExternalSorter[K, V, C](
// grow internal data structures by growing + copying every time the number of objects doubles.
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
- private def getPartition(key: K): Int = {
- if (shouldPartition) partitioner.get.getPartition(key) else 0
- }
-
- private val metaInitialRecords = 256
- private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
private val useSerializedPairBuffer =
- !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
- ser.supportsRelocationOfSerializedObjects
-
+ ordering.isEmpty &&
+ conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
+ ser.supportsRelocationOfSerializedObjects
+ private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
+ private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
+ if (useSerializedPairBuffer) {
+ new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
+ } else {
+ new PartitionedPairBuffer[K, C]
+ }
+ }
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
private var map = new PartitionedAppendOnlyMap[K, C]
- private var buffer = if (useSerializedPairBuffer) {
- new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
- } else {
- new PartitionedPairBuffer[K, C]
- }
+ private var buffer = newBuffer()
// Total spilling statistics
private var _diskBytesSpilled = 0L
+ def diskBytesSpilled: Long = _diskBytesSpilled
- // Write metrics for current spill
- private var curWriteMetrics: ShuffleWriteMetrics = _
-
- // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
- // local aggregation and sorting, write numPartitions files directly and just concatenate them
- // at the end. This avoids doing serialization and deserialization twice to merge together the
- // spilled files, which would happen with the normal code path. The downside is having multiple
- // files open at a time and thus more memory allocated to buffers.
- private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- private val bypassMergeSort =
- (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty)
-
- // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled
- private var partitionWriters: Array[BlockObjectWriter] = null
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -174,6 +166,14 @@ private[spark] class ExternalSorter[K, V, C](
}
})
+ private def comparator: Option[Comparator[K]] = {
+ if (ordering.isDefined || aggregator.isDefined) {
+ Some(keyComparator)
+ } else {
+ None
+ }
+ }
+
// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
@@ -182,9 +182,10 @@ private[spark] class ExternalSorter[K, V, C](
blockId: BlockId,
serializerBatchSizes: Array[Long],
elementsPerPartition: Array[Long])
+
private val spills = new ArrayBuffer[SpilledFile]
- def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
@@ -202,15 +203,6 @@ private[spark] class ExternalSorter[K, V, C](
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
- } else if (bypassMergeSort) {
- // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
- if (records.hasNext) {
- spillToPartitionFiles(
- WritablePartitionedIterator.fromIterator(records.map { kv =>
- ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
- })
- )
- }
} else {
// Stick values into our buffer
while (records.hasNext) {
@@ -238,46 +230,33 @@ private[spark] class ExternalSorter[K, V, C](
}
} else {
if (maybeSpill(buffer, buffer.estimateSize())) {
- buffer = if (useSerializedPairBuffer) {
- new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
- } else {
- new PartitionedPairBuffer[K, C]
- }
+ buffer = newBuffer()
}
}
}
/**
- * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
- */
- override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
- if (bypassMergeSort) {
- spillToPartitionFiles(collection)
- } else {
- spillToMergeableFile(collection)
- }
- }
-
- /**
- * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
- * We add this file into spilledFiles to find it later.
- *
- * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles()
- * is used to write files for each partition.
+ * Spill our in-memory collection to a sorted file that we can merge later.
+ * We add this file into `spilledFiles` to find it later.
*
* @param collection whichever collection we're using (map or buffer)
*/
- private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = {
- assert(!bypassMergeSort)
-
+ override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
- curWriteMetrics = new ShuffleWriteMetrics()
- var writer = blockManager.getDiskWriter(
- blockId, file, serInstance, fileBufferSize, curWriteMetrics)
- var objectsWritten = 0 // Objects written since the last flush
+
+ // These variables are reset after each flush
+ var objectsWritten: Long = 0
+ var spillMetrics: ShuffleWriteMetrics = null
+ var writer: BlockObjectWriter = null
+ def openWriter(): Unit = {
+ assert (writer == null && spillMetrics == null)
+ spillMetrics = new ShuffleWriteMetrics
+ writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
+ }
+ openWriter()
// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]
@@ -291,8 +270,9 @@ private[spark] class ExternalSorter[K, V, C](
val w = writer
writer = null
w.commitAndClose()
- _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
- batchSizes.append(curWriteMetrics.shuffleBytesWritten)
+ _diskBytesSpilled += spillMetrics.shuffleBytesWritten
+ batchSizes.append(spillMetrics.shuffleBytesWritten)
+ spillMetrics = null
objectsWritten = 0
}
@@ -307,9 +287,7 @@ private[spark] class ExternalSorter[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
- curWriteMetrics = new ShuffleWriteMetrics()
- writer = blockManager.getDiskWriter(
- blockId, file, serInstance, fileBufferSize, curWriteMetrics)
+ openWriter()
}
}
if (objectsWritten > 0) {
@@ -337,46 +315,6 @@ private[spark] class ExternalSorter[K, V, C](
}
/**
- * Spill our in-memory collection to separate files, one for each partition. This is used when
- * there's no aggregator and ordering and the number of partitions is small, because it allows
- * writePartitionedFile to just concatenate files without deserializing data.
- *
- * @param collection whichever collection we're using (map or buffer)
- */
- private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = {
- spillToPartitionFiles(collection.writablePartitionedIterator())
- }
-
- private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = {
- assert(bypassMergeSort)
-
- // Create our file writers if we haven't done so yet
- if (partitionWriters == null) {
- curWriteMetrics = new ShuffleWriteMetrics()
- val openStartTime = System.nanoTime
- partitionWriters = Array.fill(numPartitions) {
- // Because these files may be read during shuffle, their compression must be controlled by
- // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
- // createTempShuffleBlock here; see SPARK-3426 for more context.
- val (blockId, file) = diskBlockManager.createTempShuffleBlock()
- val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize,
- curWriteMetrics)
- writer.open()
- }
- // Creating the file to write to and creating a disk writer both involve interacting with
- // the disk, and can take a long time in aggregate when we open many files, so should be
- // included in the shuffle write time.
- curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
- }
-
- // No need to sort stuff, just write each element out
- while (iterator.hasNext) {
- val partitionId = iterator.nextPartition()
- iterator.writeNext(partitionWriters(partitionId))
- }
- }
-
- /**
* Merge a sequence of sorted files, giving an iterator over partitions and then over elements
* inside each partition. This can be used to either write out a new file or return data to
* the user.
@@ -665,8 +603,6 @@ private[spark] class ExternalSorter[K, V, C](
}
/**
- * Exposed for testing purposes.
- *
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
* contents, and these are expected to be accessed in order (you can't "skip ahead" to one
@@ -676,10 +612,11 @@ private[spark] class ExternalSorter[K, V, C](
* For now, we just merge all the spilled files in once pass, but this can be modified to
* support hierarchical merging.
*/
- def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+ @VisibleForTesting
+ def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
- if (spills.isEmpty && partitionWriters == null) {
+ if (spills.isEmpty) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
@@ -689,13 +626,6 @@ private[spark] class ExternalSorter[K, V, C](
// We do need to sort by both partition ID and key
groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
}
- } else if (bypassMergeSort) {
- // Read data from each partition file and merge it together with the data in memory;
- // note that there's no ordering or aggregator in this case -- we just partition objects
- val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None))
- collIter.map { case (partitionId, values) =>
- (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
- }
} else {
// Merge spilled and in-memory data
merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
@@ -709,14 +639,13 @@ private[spark] class ExternalSorter[K, V, C](
/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
- * called by the SortShuffleWriter and can go through an efficient path of just concatenating
- * binary files if we decided to avoid merge-sorting.
+ * called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
- def writePartitionedFile(
+ override def writePartitionedFile(
blockId: BlockId,
context: TaskContext,
outputFile: File): Array[Long] = {
@@ -724,28 +653,7 @@ private[spark] class ExternalSorter[K, V, C](
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
- if (bypassMergeSort && partitionWriters != null) {
- // We decided to write separate files for each partition, so just concatenate them. To keep
- // this simple we spill out the current in-memory collection so that everything is in files.
- spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
- partitionWriters.foreach(_.commitAndClose())
- val out = new FileOutputStream(outputFile, true)
- val writeStartTime = System.nanoTime
- util.Utils.tryWithSafeFinally {
- for (i <- 0 until numPartitions) {
- val in = new FileInputStream(partitionWriters(i).fileSegment().file)
- util.Utils.tryWithSafeFinally {
- lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled)
- } {
- in.close()
- }
- }
- } {
- out.close()
- context.taskMetrics.shuffleWriteMetrics.foreach(
- _.incShuffleWriteTime(System.nanoTime - writeStartTime))
- }
- } else if (spills.isEmpty && partitionWriters == null) {
+ if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
@@ -761,7 +669,7 @@ private[spark] class ExternalSorter[K, V, C](
lengths(partitionId) = segment.length
}
} else {
- // Not bypassing merge-sort; get an iterator by partition and just write everything directly.
+ // We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
@@ -778,41 +686,15 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
- context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
- if (curWriteMetrics != null) {
- m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten)
- m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime)
- m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten)
- }
- }
lengths
}
- /**
- * Read a partition file back as an iterator (used in our iterator method)
- */
- private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
- if (writer.isOpen) {
- writer.commitAndClose()
- }
- new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get)
- }
-
def stop(): Unit = {
spills.foreach(s => s.file.delete())
spills.clear()
- if (partitionWriters != null) {
- partitionWriters.foreach { w =>
- w.revertPartialWritesAndClose()
- diskBlockManager.getFile(w.blockId).delete()
- }
- partitionWriters = null
- }
}
- def diskBytesSpilled: Long = _diskBytesSpilled
-
/**
* Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
* group together the pairs for each partition into a sub-iterator.
@@ -826,14 +708,6 @@ private[spark] class ExternalSorter[K, V, C](
(0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
}
- private def comparator: Option[Comparator[K]] = {
- if (ordering.isDefined || aggregator.isDefined) {
- Some(keyComparator)
- } else {
- None
- }
- }
-
/**
* An iterator that reads only the elements for a given partition ID from an underlying buffered
* stream, assuming this partition is the next one to be read. Used to make it easier to return
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
deleted file mode 100644
index d75959f480..0000000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
- def hasNext: Boolean = iter.hasNext
-
- def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
index e2e2f1faae..d0d25b43d0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V]
destructiveSortedIterator(comparator)
}
- def writablePartitionedIterator(): WritablePartitionedIterator = {
- WritablePartitionedIterator.fromIterator(super.iterator)
- }
-
def insert(partition: Int, key: K, value: V): Unit = {
update((partition, key), value)
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index e8332e1a87..5a6e9a9580 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
iterator
}
- override def writablePartitionedIterator(): WritablePartitionedIterator = {
- WritablePartitionedIterator.fromIterator(iterator)
- }
-
private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
var pos = 0
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index 554d88206e..862408b7a4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -122,10 +122,6 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
sort(keyComparator)
- writablePartitionedIterator
- }
-
- override def writablePartitionedIterator(): WritablePartitionedIterator = {
new WritablePartitionedIterator {
// current position in the meta buffer in ints
var pos = 0
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index f26d1618c9..7bc5989865 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
*/
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
- WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
- }
+ val it = partitionedDestructiveSortedIterator(keyComparator)
+ new WritablePartitionedIterator {
+ private[this] var cur = if (it.hasNext) it.next() else null
- /**
- * Iterate through the data and write out the elements instead of returning them.
- */
- def writablePartitionedIterator(): WritablePartitionedIterator
+ def writeNext(writer: BlockObjectWriter): Unit = {
+ writer.write(cur._1._2, cur._2)
+ cur = if (it.hasNext) it.next() else null
+ }
+
+ def hasNext(): Boolean = cur != null
+
+ def nextPartition(): Int = cur._1._1
+ }
+ }
}
private[spark] object WritablePartitionedPairCollection {
@@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator {
def nextPartition(): Int
}
-
-private[spark] object WritablePartitionedIterator {
- def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
- new WritablePartitionedIterator {
- var cur = if (it.hasNext) it.next() else null
-
- def writeNext(writer: BlockObjectWriter): Unit = {
- writer.write(cur._1._2, cur._2)
- cur = if (it.hasNext) it.next() else null
- }
-
- def hasNext(): Boolean = cur != null
-
- def nextPartition(): Int = cur._1._1
- }
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 91f4ab3608..c3c2b1ffc1 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.Matchers
import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
import org.apache.spark.util.MutablePair
@@ -281,6 +282,39 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// This count should retry the execution of the previous stage and rerun shuffle.
rdd.count()
}
+
+ test("metrics for shuffle without aggregation") {
+ sc = new SparkContext("local", "test", conf.clone())
+ val numRecords = 10000
+
+ val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+ sc.parallelize(1 to numRecords, 4)
+ .map(key => (key, 1))
+ .groupByKey()
+ .collect()
+ }
+
+ assert(metrics.recordsRead === numRecords)
+ assert(metrics.recordsWritten === numRecords)
+ assert(metrics.bytesWritten === metrics.bytesRead)
+ assert(metrics.bytesWritten > 0)
+ }
+
+ test("metrics for shuffle with aggregation") {
+ sc = new SparkContext("local", "test", conf.clone())
+ val numRecords = 10000
+
+ val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+ sc.parallelize(1 to numRecords, 4)
+ .flatMap(key => Array.fill(100)((key, 1)))
+ .countByKey()
+ }
+
+ assert(metrics.recordsRead === numRecords)
+ assert(metrics.recordsWritten === numRecords)
+ assert(metrics.bytesWritten === metrics.bytesRead)
+ assert(metrics.bytesWritten > 0)
+ }
}
object ShuffleSuite {
@@ -294,4 +328,35 @@ object ShuffleSuite {
value - o.value
}
}
+
+ case class AggregatedShuffleMetrics(
+ recordsWritten: Long,
+ recordsRead: Long,
+ bytesWritten: Long,
+ bytesRead: Long)
+
+ def runAndReturnMetrics(sc: SparkContext)(job: => Unit): AggregatedShuffleMetrics = {
+ @volatile var recordsWritten: Long = 0
+ @volatile var recordsRead: Long = 0
+ @volatile var bytesWritten: Long = 0
+ @volatile var bytesRead: Long = 0
+ val listener = new SparkListener {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m =>
+ recordsWritten += m.shuffleRecordsWritten
+ bytesWritten += m.shuffleBytesWritten
+ }
+ taskEnd.taskMetrics.shuffleReadMetrics.foreach { m =>
+ recordsRead += m.recordsRead
+ bytesRead += m.totalBytesRead
+ }
+ }
+ }
+ sc.addSparkListener(listener)
+
+ job
+
+ sc.listenerBus.waitUntilEmpty(500)
+ AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 19f1af0dcd..9e4d34fb7d 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -193,26 +193,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
assert(records == numRecords)
}
- test("shuffle records read metrics") {
- val recordsRead = runAndReturnShuffleRecordsRead {
- sc.textFile(tmpFilePath, 4)
- .map(key => (key, 1))
- .groupByKey()
- .collect()
- }
- assert(recordsRead == numRecords)
- }
-
- test("shuffle records written metrics") {
- val recordsWritten = runAndReturnShuffleRecordsWritten {
- sc.textFile(tmpFilePath, 4)
- .map(key => (key, 1))
- .groupByKey()
- .collect()
- }
- assert(recordsWritten == numRecords)
- }
-
/**
* Tests the metrics from end to end.
* 1) reading a hadoop file
@@ -301,14 +281,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten))
}
- private def runAndReturnShuffleRecordsRead(job: => Unit): Long = {
- runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead))
- }
-
- private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = {
- runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten))
- }
-
private def runAndReturnMetrics(job: => Unit,
collector: (SparkListenerTaskEnd) => Option[Long]): Long = {
val taskMetrics = new ArrayBuffer[Long]()
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
new file mode 100644
index 0000000000..c8420db612
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.File
+import java.util.UUID
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+import org.apache.spark._
+import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+
+class BypassMergeSortShuffleWriterSuite extends FunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+
+ private var taskMetrics: TaskMetrics = _
+ private var shuffleWriteMetrics: ShuffleWriteMetrics = _
+ private var tempDir: File = _
+ private var outputFile: File = _
+ private val conf: SparkConf = new SparkConf(loadDefaults = false)
+ private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
+ private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
+ private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
+ private val serializer: Serializer = new JavaSerializer(conf)
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ outputFile = File.createTempFile("shuffle", null, tempDir)
+ shuffleWriteMetrics = new ShuffleWriteMetrics
+ taskMetrics = new TaskMetrics
+ taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
+ MockitoAnnotations.initMocks(this)
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+ when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+ when(blockManager.getDiskWriter(
+ any[BlockId],
+ any[File],
+ any[SerializerInstance],
+ anyInt(),
+ any[ShuffleWriteMetrics]
+ )).thenAnswer(new Answer[BlockObjectWriter] {
+ override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+ val args = invocation.getArguments
+ new DiskBlockObjectWriter(
+ args(0).asInstanceOf[BlockId],
+ args(1).asInstanceOf[File],
+ args(2).asInstanceOf[SerializerInstance],
+ args(3).asInstanceOf[Int],
+ compressStream = identity,
+ syncWrites = false,
+ args(4).asInstanceOf[ShuffleWriteMetrics]
+ )
+ }
+ })
+ when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+ new Answer[(TempShuffleBlockId, File)] {
+ override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = {
+ val blockId = new TempShuffleBlockId(UUID.randomUUID)
+ val file = File.createTempFile(blockId.toString, null, tempDir)
+ blockIdToFileMap.put(blockId, file)
+ temporaryFilesCreated.append(file)
+ (blockId, file)
+ }
+ })
+ when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+ new Answer[File] {
+ override def answer(invocation: InvocationOnMock): File = {
+ blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get
+ }
+ })
+ }
+
+ override def afterEach(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ blockIdToFileMap.clear()
+ temporaryFilesCreated.clear()
+ }
+
+ test("write empty iterator") {
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ new SparkConf(loadDefaults = false),
+ blockManager,
+ new HashPartitioner(7),
+ shuffleWriteMetrics,
+ serializer
+ )
+ writer.insertAll(Iterator.empty)
+ val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+ assert(partitionLengths.sum === 0)
+ assert(outputFile.exists())
+ assert(outputFile.length() === 0)
+ assert(temporaryFilesCreated.isEmpty)
+ assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
+ assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ test("write with some empty partitions") {
+ def records: Iterator[(Int, Int)] =
+ Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ new SparkConf(loadDefaults = false),
+ blockManager,
+ new HashPartitioner(7),
+ shuffleWriteMetrics,
+ serializer
+ )
+ writer.insertAll(records)
+ assert(temporaryFilesCreated.nonEmpty)
+ val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+ assert(partitionLengths.sum === outputFile.length())
+ assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+ assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
+ assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ test("cleanup of intermediate files after errors") {
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ new SparkConf(loadDefaults = false),
+ blockManager,
+ new HashPartitioner(7),
+ shuffleWriteMetrics,
+ serializer
+ )
+ intercept[SparkException] {
+ writer.insertAll((0 until 100000).iterator.map(i => {
+ if (i == 99990) {
+ throw new SparkException("Intentional failure")
+ }
+ (i, i)
+ }))
+ }
+ assert(temporaryFilesCreated.nonEmpty)
+ writer.stop()
+ assert(temporaryFilesCreated.count(_.exists()) === 0)
+ }
+
+}
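
Reviewer-facing sketch (commentary, not code from this patch): the suite above drives BypassMergeSortShuffleWriter entirely through mocked block managers, so the intended lifecycle is easy to lose in the setup. A minimal sketch of that lifecycle, using only calls that appear in the tests; blockManager, shuffleWriteMetrics, serializer, shuffleBlockId, taskContext and outputFile are assumed to be built as in beforeEach().

    // Sketch only: the call sequence the three tests above exercise.
    val writer = new BypassMergeSortShuffleWriter[Int, Int](
      new SparkConf(loadDefaults = false),
      blockManager,
      new HashPartitioner(7),
      shuffleWriteMetrics,
      serializer)
    writer.insertAll(Iterator((1, 1), (2, 2)))   // records go straight to one temp file per partition
    // Concatenate the per-partition temp files into the single map output file:
    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
    assert(partitionLengths.sum == outputFile.length())
    // If insertAll() fails, stop() deletes the temporary per-partition files (see the last test).
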
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
new file mode 100644
index 0000000000..c6ada7139c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.mockito.Mockito._
+import org.scalatest.FunSuite
+
+import org.apache.spark.{Aggregator, SparkConf}
+
+class SortShuffleWriterSuite extends FunSuite {
+
+ import SortShuffleWriter._
+
+ test("conditions for bypassing merge-sort") {
+ val conf = new SparkConf(loadDefaults = false)
+ val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
+ val ord = implicitly[Ordering[Int]]
+
+ // Numbers of partitions that are above and below the default bypassMergeThreshold
+ val FEW_PARTITIONS = 50
+ val MANY_PARTITIONS = 10000
+
+ // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
+ assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
+ assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
+
+ // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
+ assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
+ assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
+ }
+}
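
Commentary on the suite above: the bypass decision is now a pure function of the conf, the partition count, and the presence of an aggregator or ordering, which is what lets it be tested without constructing an ExternalSorter. A rough sketch of the rule the assertions imply (hypothetical helper name; the threshold comes from spark.shuffle.sort.bypassMergeThreshold, default 200, and the exact body in SortShuffleWriter may differ):

    import org.apache.spark.{Aggregator, SparkConf}

    // Sketch of the behaviour asserted above, not the patch's exact implementation.
    def shouldBypassMergeSortSketch(
        conf: SparkConf,
        numPartitions: Int,
        aggregator: Option[Aggregator[_, _, _]],
        keyOrdering: Option[Ordering[_]]): Boolean = {
      // Bypass only when no map-side aggregation or key ordering is required and the
      // number of reduce partitions does not exceed the configured threshold.
      val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
      aggregator.isEmpty && keyOrdering.isEmpty && numPartitions <= threshold
    }
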
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index ad43a3e5fd..7bdea724fe 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -18,14 +18,27 @@ package org.apache.spark.storage
import java.io.File
+import org.scalatest.BeforeAndAfterEach
+
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-class BlockObjectWriterSuite extends SparkFunSuite {
+class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ var tempDir: File = _
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ }
+
+ override def afterEach(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ }
+
test("verify write metrics") {
- val file = new File(Utils.createTempDir(), "somefile")
+ val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -47,7 +61,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
}
test("verify write metrics on revert") {
- val file = new File(Utils.createTempDir(), "somefile")
+ val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -70,7 +84,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
}
test("Reopening a closed block writer") {
- val file = new File(Utils.createTempDir(), "somefile")
+ val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -81,4 +95,79 @@ class BlockObjectWriterSuite extends SparkFunSuite {
writer.open()
}
}
+
+ test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ writer.commitAndClose()
+ val bytesWritten = writeMetrics.shuffleBytesWritten
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ writer.revertPartialWritesAndClose()
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+ }
+
+ test("commitAndClose() should be idempotent") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ writer.commitAndClose()
+ val bytesWritten = writeMetrics.shuffleBytesWritten
+ val writeTime = writeMetrics.shuffleWriteTime
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ writer.commitAndClose()
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+ assert(writeMetrics.shuffleWriteTime === writeTime)
+ }
+
+ test("revertPartialWritesAndClose() should be idempotent") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ writer.revertPartialWritesAndClose()
+ val bytesWritten = writeMetrics.shuffleBytesWritten
+ val writeTime = writeMetrics.shuffleWriteTime
+ assert(writeMetrics.shuffleRecordsWritten === 0)
+ writer.revertPartialWritesAndClose()
+ assert(writeMetrics.shuffleRecordsWritten === 0)
+ assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+ assert(writeMetrics.shuffleWriteTime === writeTime)
+ }
+
+ test("fileSegment() can only be called after commitAndClose() has been called") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ intercept[IllegalStateException] {
+ writer.fileSegment()
+ }
+ writer.close()
+ }
+
+ test("commitAndClose() without ever opening or writing") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ writer.commitAndClose()
+ assert(writer.fileSegment().length === 0)
+ }
}
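
Taken together, the new DiskBlockObjectWriter tests above pin down a small lifecycle contract: fileSegment() is only legal after commitAndClose(), and commitAndClose() / revertPartialWritesAndClose() can be called repeatedly without disturbing the shuffle write metrics. A condensed sketch of the safe usage pattern (commentary only; file and writeMetrics as constructed in the tests):

    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
    try {
      (1 to 1000).foreach(i => writer.write(i, i))
      writer.commitAndClose()                // only after this is fileSegment() valid
      val segment = writer.fileSegment()     // offset and length of the bytes just written
    } catch {
      case t: Throwable =>
        writer.revertPartialWritesAndClose() // idempotent; safe even if the writer is already closed
        throw t
    }
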
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 9039dbef1f..7d7b41bc23 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -23,10 +23,10 @@ import org.scalatest.PrivateMethodTester
import scala.util.Random
import org.apache.spark._
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester {
+class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
val conf = new SparkConf(loadDefaults)
if (kryo) {
@@ -37,21 +39,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
conf.set("spark.serializer.objectStreamReset", "1")
conf.set("spark.serializer", classOf[JavaSerializer].getName)
}
+ conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
// Ensure that we actually have multiple batches per spill file
conf.set("spark.shuffle.spill.batchSize", "10")
conf
}
- private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
- val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
- assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort")
- }
-
- private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
- val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
- assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort")
- }
-
test("empty data stream with kryo ser") {
emptyDataStream(createSparkConf(false, true))
}
@@ -161,39 +154,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(7)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter)
- sorter.insertAll(elements)
- assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
- val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
- assert(iter.next() === (0, Nil))
- assert(iter.next() === (1, List((1, 1))))
- assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
- assert(iter.next() === (3, Nil))
- assert(iter.next() === (4, Nil))
- assert(iter.next() === (5, List((5, 5))))
- assert(iter.next() === (6, Nil))
- sorter.stop()
- }
-
- test("empty partitions with spilling, bypass merge-sort with kryo ser") {
- emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, true))
- }
-
- test("empty partitions with spilling, bypass merge-sort with java ser") {
- emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, false))
- }
-
- def emptyPartitionerWithSpillingBypassMergeSort(conf: SparkConf) {
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
-
- val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
-
- val sorter = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(7)), None, None)
- assertBypassedMergeSort(sorter)
sorter.insertAll(elements)
assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -376,7 +336,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(3)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter)
sorter.insertAll((0 until 120000).iterator.map(i => (i, i)))
assert(diskBlockManager.getAllFiles().length > 0)
sorter.stop()
@@ -384,7 +343,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter2 = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(3)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter2)
sorter2.insertAll((0 until 120000).iterator.map(i => (i, i)))
assert(diskBlockManager.getAllFiles().length > 0)
assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet)
@@ -392,29 +350,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
assert(diskBlockManager.getAllBlocks().length === 0)
}
- test("cleanup of intermediate files in sorter, bypass merge-sort") {
- val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
- val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
- val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
- assertBypassedMergeSort(sorter)
- sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
- assert(diskBlockManager.getAllFiles().length > 0)
- sorter.stop()
- assert(diskBlockManager.getAllBlocks().length === 0)
-
- val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
- assertBypassedMergeSort(sorter2)
- sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
- assert(diskBlockManager.getAllFiles().length > 0)
- assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
- sorter2.stop()
- assert(diskBlockManager.getAllBlocks().length === 0)
- }
-
test("cleanup of intermediate files in sorter if there are errors") {
val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -426,7 +361,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(3)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter)
intercept[SparkException] {
sorter.insertAll((0 until 120000).iterator.map(i => {
if (i == 119990) {
@@ -440,28 +374,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
assert(diskBlockManager.getAllBlocks().length === 0)
}
- test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") {
- val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
- val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
- val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
- assertBypassedMergeSort(sorter)
- intercept[SparkException] {
- sorter.insertAll((0 until 100000).iterator.map(i => {
- if (i == 99990) {
- throw new SparkException("Intentional failure")
- }
- (i, i)
- }))
- }
- assert(diskBlockManager.getAllFiles().length > 0)
- sorter.stop()
- assert(diskBlockManager.getAllBlocks().length === 0)
- }
-
test("cleanup of intermediate files in shuffle") {
val conf = createSparkConf(false, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -776,40 +688,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
}
}
- test("conditions for bypassing merge-sort") {
- val conf = createSparkConf(false, false)
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
-
- val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
- val ord = implicitly[Ordering[Int]]
-
- // Numbers of partitions that are above and below the default bypassMergeThreshold
- val FEW_PARTITIONS = 50
- val MANY_PARTITIONS = 10000
-
- // Sorters with no ordering or aggregator: should bypass unless # of partitions is high
-
- val sorter1 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
- assertBypassedMergeSort(sorter1)
-
- val sorter2 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None)
- assertDidNotBypassMergeSort(sorter2)
-
- // Sorters with an ordering or aggregator: should not bypass even if they have few partitions
-
- val sorter3 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter3)
-
- val sorter4 = new ExternalSorter[Int, Int, Int](
- Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
- assertDidNotBypassMergeSort(sorter4)
- }
-
test("sort without breaking sorting contracts with kryo ser") {
sortWithoutBreakingSortingContracts(createSparkConf(true, true))
}