-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java  184
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java  53
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala  34
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala  19
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/FileSegment.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala  260
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala  24
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala  4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala  36
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala  65
-rw-r--r--  core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala  28
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala  171
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala  46
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala  97
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala  130
17 files changed, 738 insertions, 423 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000000..d3d6280284
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ * <p>
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ * <ul>
+ * <li>no Ordering is specified,</li>
+ * <li>no Aggregator is specified, and</li>
+ * <li>the number of partitions is less than
+ * <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ * <p>
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+ private final int fileBufferSize;
+ private final boolean transferToEnabled;
+ private final int numPartitions;
+ private final BlockManager blockManager;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final Serializer serializer;
+
+ /** Array of file writers, one for each partition */
+ private BlockObjectWriter[] partitionWriters;
+
+ public BypassMergeSortShuffleWriter(
+ SparkConf conf,
+ BlockManager blockManager,
+ Partitioner partitioner,
+ ShuffleWriteMetrics writeMetrics,
+ Serializer serializer) {
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+ this.numPartitions = partitioner.numPartitions();
+ this.blockManager = blockManager;
+ this.partitioner = partitioner;
+ this.writeMetrics = writeMetrics;
+ this.serializer = serializer;
+ }
+
+ @Override
+ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+ assert (partitionWriters == null);
+ if (!records.hasNext()) {
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new BlockObjectWriter[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+ }
+
+ for (BlockObjectWriter writer : partitionWriters) {
+ writer.commitAndClose();
+ }
+ }
+
+ @Override
+ public long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException {
+ // Track location of the partition starts in the output file
+ final long[] lengths = new long[numPartitions];
+ if (partitionWriters == null) {
+ // We were passed an empty iterator
+ return lengths;
+ }
+
+ final FileOutputStream out = new FileOutputStream(outputFile, true);
+ final long writeStartTime = System.nanoTime();
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < numPartitions; i++) {
+ final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+ boolean copyThrewException = true;
+ try {
+ lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+ logger.error("Unable to delete file for partition {}", i);
+ }
+ }
+ threwException = false;
+ } finally {
+ Closeables.close(out, threwException);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ }
+ partitionWriters = null;
+ return lengths;
+ }
+
+ @Override
+ public void stop() throws IOException {
+ if (partitionWriters != null) {
+ try {
+ final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+ for (BlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ writer.revertPartialWritesAndClose();
+ if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+ logger.error("Error while deleting file for block {}", writer.blockId());
+ }
+ }
+ } finally {
+ partitionWriters = null;
+ }
+ }
+ }
+}
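
The sequence below is a minimal editorial sketch (not part of the patch) of how SortShuffleWriter drives this writer, based on the constructor call and interface methods above. `blockManager`, `dep`, `records`, `blockId`, `context`, `outputFile`, and `writeMetrics` stand in for state owned by the surrounding shuffle write task; a fully mocked, runnable version is BypassMergeSortShuffleWriterSuite further down in this patch.

    val writer = new BypassMergeSortShuffleWriter[K, V](
      SparkEnv.get.conf, blockManager, dep.partitioner, writeMetrics,
      Serializer.getSerializer(dep.serializer))
    try {
      writer.insertAll(records)      // opens one DiskBlockObjectWriter per reduce partition
      val lengths = writer.writePartitionedFile(blockId, context, outputFile)
      // lengths(i) is the size, in bytes, of partition i's region inside outputFile
    } catch {
      case e: Exception =>
        writer.stop()                // reverts partial writes and deletes the temp files
        throw e
    }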
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
new file mode 100644
index 0000000000..656ea0401a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
+
+/**
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
+ */
+@Private
+public interface SortShuffleFileWriter<K, V> {
+
+ void insertAll(Iterator<Product2<K, V>> records) throws IOException;
+
+ /**
+ * Write all the data added into this shuffle sorter into a file in the disk store. This is
+ * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+ * binary files if we decided to avoid merge-sorting.
+ *
+ * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+ * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+ */
+ long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException;
+
+ void stop() throws IOException;
+}
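
As an editorial illustration of this contract (not part of the patch), here is a trivial Scala implementation that consumes and drops all records and reports every partition as empty; the two real implementations in this patch are ExternalSorter and BypassMergeSortShuffleWriter.

    import java.io.File

    import org.apache.spark.TaskContext
    import org.apache.spark.shuffle.sort.SortShuffleFileWriter
    import org.apache.spark.storage.BlockId

    class DiscardingShuffleFileWriter[K, V](numPartitions: Int) extends SortShuffleFileWriter[K, V] {
      override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
        while (records.hasNext) { records.next() }  // consume and discard
      }
      override def writePartitionedFile(
          blockId: BlockId,
          context: TaskContext,
          outputFile: File): Array[Long] = {
        new Array[Long](numPartitions)  // every partition has length 0
      }
      override def stop(): Unit = {}
    }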
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index c9dd6bfc4c..5865e7640c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,9 +17,10 @@
package org.apache.spark.shuffle.sort
-import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
+import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
@@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
- private var sorter: ExternalSorter[K, V, _] = null
+ private var sorter: SortShuffleFileWriter[K, V] = null
// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
@@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C](
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
- if (dep.mapSideCombine) {
+ sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
- sorter = new ExternalSorter[K, V, C](
+ new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
- sorter.insertAll(records)
+ } else if (SortShuffleWriter.shouldBypassMergeSort(
+ SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
+ // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+ // need local aggregation and sorting, write numPartitions files directly and just concatenate
+ // them at the end. This avoids doing serialization and deserialization twice to merge
+ // together the spilled files, which would happen with the normal code path. The downside is
+ // having multiple files open at a time and thus more memory allocated to buffers.
+ new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
+ writeMetrics, Serializer.getSerializer(dep.serializer))
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
- sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer)
- sorter.insertAll(records)
+ new ExternalSorter[K, V, V](
+ aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
+ sorter.insertAll(records)
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
@@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C](
}
}
+private[spark] object SortShuffleWriter {
+ def shouldBypassMergeSort(
+ conf: SparkConf,
+ numPartitions: Int,
+ aggregator: Option[Aggregator[_, _, _]],
+ keyOrdering: Option[Ordering[_]]): Boolean = {
+ val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+ }
+}
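
A small editorial sketch (not part of the patch) of how the threshold interacts with this check; it assumes the calling code sits inside the org.apache.spark namespace, since shouldBypassMergeSort is private[spark].

    import org.apache.spark.SparkConf
    import org.apache.spark.shuffle.sort.SortShuffleWriter

    val conf = new SparkConf(loadDefaults = false)
    // With the default threshold of 200, 500 reduce partitions do not bypass merge-sort ...
    assert(!SortShuffleWriter.shouldBypassMergeSort(conf, 500, aggregator = None, keyOrdering = None))
    // ... unless the threshold is raised explicitly.
    conf.set("spark.shuffle.sort.bypassMergeThreshold", "1000")
    assert(SortShuffleWriter.shouldBypassMergeSort(conf, 500, aggregator = None, keyOrdering = None))
    // An aggregator or key ordering always forces the ExternalSorter path, regardless of the threshold.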
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index a33f22ef52..7eeabd1e04 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter(
private var objOut: SerializationStream = null
private var initialized = false
private var hasBeenClosed = false
+ private var commitAndCloseHasBeenCalled = false
/**
* Cursors used to represent positions in the file.
@@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter(
objOut.flush()
bs.flush()
close()
+ finalPosition = file.length()
+ // In certain compression codecs, more bytes are written after close() is called
+ writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+ } else {
+ finalPosition = file.length()
}
- finalPosition = file.length()
- // In certain compression codecs, more bytes are written after close() is called
- writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+ commitAndCloseHasBeenCalled = true
}
// Discard current writes. We do this by flushing the outstanding writes and then
// truncating the file to its initial position.
override def revertPartialWritesAndClose() {
try {
- writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
- writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
-
if (initialized) {
+ writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
+ writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
objOut.flush()
bs.flush()
close()
@@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter(
}
override def fileSegment(): FileSegment = {
+ if (!commitAndCloseHasBeenCalled) {
+ throw new IllegalStateException(
+ "fileSegment() is only valid after commitAndClose() has been called")
+ }
new FileSegment(file, initialPosition, finalPosition - initialPosition)
}
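
A minimal editorial sketch (not part of the patch) of the new contract: fileSegment() is only legal after commitAndClose(). It mirrors the constructor call used in BlockObjectWriterSuite below; DiskBlockObjectWriter and TestBlockId are private[spark], so this only compiles from code inside org.apache.spark.

    import java.io.File

    import org.apache.spark.SparkConf
    import org.apache.spark.executor.ShuffleWriteMetrics
    import org.apache.spark.serializer.JavaSerializer
    import org.apache.spark.storage.{DiskBlockObjectWriter, TestBlockId}
    import org.apache.spark.util.Utils

    val file = new File(Utils.createTempDir(), "segment-demo")
    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true,
      new ShuffleWriteMetrics())
    writer.write(1, 1)
    // Calling writer.fileSegment() here would throw IllegalStateException,
    // because commitAndClose() has not been called yet.
    writer.commitAndClose()
    val segment = writer.fileSegment()  // valid: offset 0, length = bytes actually written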
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
index 95e2d688d9..021a9facfb 100644
--- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -24,6 +24,8 @@ import java.io.File
* based off an offset and a length.
*/
private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
+ require(offset >= 0, s"File segment offset cannot be negative (got $offset)")
+ require(length >= 0, s"File segment length cannot be negative (got $length)")
override def toString: String = {
"(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
}
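
For illustration only (not part of the patch), the new checks turn a nonsensical segment into an immediate failure; FileSegment is private[spark], and the path below is just an example.

    new FileSegment(new java.io.File("/tmp/example.data"), offset = 0, length = -1)
    // => java.lang.IllegalArgumentException: requirement failed:
    //    File segment length cannot be negative (got -1)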
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3b9d14f937..ef2dbb7ff0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -23,12 +23,14 @@ import java.util.Comparator
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
+import com.google.common.annotations.VisibleForTesting
import com.google.common.io.ByteStreams
import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.storage.{BlockObjectWriter, BlockId}
+import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
+import org.apache.spark.storage.{BlockId, BlockObjectWriter}
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -84,35 +86,40 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
* each other for equality to merge values.
*
* - Users are expected to call stop() at the end to delete all the intermediate files.
- *
- * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
- * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
- * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
- * then concatenate these files to produce a single sorted file, without having to serialize and
- * de-serialize each item twice (as is needed during the merge). This speeds up the map side of
- * groupBy, sort, etc operations since they do no partial aggregation.
*/
private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
- extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] {
+ extends Logging
+ with Spillable[WritablePartitionedPairCollection[K, C]]
+ with SortShuffleFileWriter[K, V] {
+
+ private val conf = SparkEnv.get.conf
private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
private val shouldPartition = numPartitions > 1
+ private def getPartition(key: K): Int = {
+ if (shouldPartition) partitioner.get.getPartition(key) else 0
+ }
+
+ // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
+ // As a sanity check, make sure that we're not handling a shuffle which should use that path.
+ if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
+ throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+ + " a sort that the BypassMergeSortShuffleWriter should handle")
+ }
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
private val ser = Serializer.getSerializer(serializer)
private val serInstance = ser.newInstance()
- private val conf = SparkEnv.get.conf
private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
- private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
// Size of object batches when reading/writing from serializers.
//
@@ -123,43 +130,28 @@ private[spark] class ExternalSorter[K, V, C](
// grow internal data structures by growing + copying every time the number of objects doubles.
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
- private def getPartition(key: K): Int = {
- if (shouldPartition) partitioner.get.getPartition(key) else 0
- }
-
- private val metaInitialRecords = 256
- private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
private val useSerializedPairBuffer =
- !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
- ser.supportsRelocationOfSerializedObjects
-
+ ordering.isEmpty &&
+ conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
+ ser.supportsRelocationOfSerializedObjects
+ private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
+ private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
+ if (useSerializedPairBuffer) {
+ new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
+ } else {
+ new PartitionedPairBuffer[K, C]
+ }
+ }
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
private var map = new PartitionedAppendOnlyMap[K, C]
- private var buffer = if (useSerializedPairBuffer) {
- new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
- } else {
- new PartitionedPairBuffer[K, C]
- }
+ private var buffer = newBuffer()
// Total spilling statistics
private var _diskBytesSpilled = 0L
+ def diskBytesSpilled: Long = _diskBytesSpilled
- // Write metrics for current spill
- private var curWriteMetrics: ShuffleWriteMetrics = _
-
- // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
- // local aggregation and sorting, write numPartitions files directly and just concatenate them
- // at the end. This avoids doing serialization and deserialization twice to merge together the
- // spilled files, which would happen with the normal code path. The downside is having multiple
- // files open at a time and thus more memory allocated to buffers.
- private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- private val bypassMergeSort =
- (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty)
-
- // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled
- private var partitionWriters: Array[BlockObjectWriter] = null
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided by the
@@ -174,6 +166,14 @@ private[spark] class ExternalSorter[K, V, C](
}
})
+ private def comparator: Option[Comparator[K]] = {
+ if (ordering.isDefined || aggregator.isDefined) {
+ Some(keyComparator)
+ } else {
+ None
+ }
+ }
+
// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
@@ -182,9 +182,10 @@ private[spark] class ExternalSorter[K, V, C](
blockId: BlockId,
serializerBatchSizes: Array[Long],
elementsPerPartition: Array[Long])
+
private val spills = new ArrayBuffer[SpilledFile]
- def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
@@ -202,15 +203,6 @@ private[spark] class ExternalSorter[K, V, C](
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
- } else if (bypassMergeSort) {
- // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
- if (records.hasNext) {
- spillToPartitionFiles(
- WritablePartitionedIterator.fromIterator(records.map { kv =>
- ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
- })
- )
- }
} else {
// Stick values into our buffer
while (records.hasNext) {
@@ -238,46 +230,33 @@ private[spark] class ExternalSorter[K, V, C](
}
} else {
if (maybeSpill(buffer, buffer.estimateSize())) {
- buffer = if (useSerializedPairBuffer) {
- new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
- } else {
- new PartitionedPairBuffer[K, C]
- }
+ buffer = newBuffer()
}
}
}
/**
- * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
- */
- override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
- if (bypassMergeSort) {
- spillToPartitionFiles(collection)
- } else {
- spillToMergeableFile(collection)
- }
- }
-
- /**
- * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
- * We add this file into spilledFiles to find it later.
- *
- * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles()
- * is used to write files for each partition.
+ * Spill our in-memory collection to a sorted file that we can merge later.
+ * We add this file into `spilledFiles` to find it later.
*
* @param collection whichever collection we're using (map or buffer)
*/
- private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = {
- assert(!bypassMergeSort)
-
+ override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
- curWriteMetrics = new ShuffleWriteMetrics()
- var writer = blockManager.getDiskWriter(
- blockId, file, serInstance, fileBufferSize, curWriteMetrics)
- var objectsWritten = 0 // Objects written since the last flush
+
+ // These variables are reset after each flush
+ var objectsWritten: Long = 0
+ var spillMetrics: ShuffleWriteMetrics = null
+ var writer: BlockObjectWriter = null
+ def openWriter(): Unit = {
+ assert (writer == null && spillMetrics == null)
+ spillMetrics = new ShuffleWriteMetrics
+ writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
+ }
+ openWriter()
// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]
@@ -291,8 +270,9 @@ private[spark] class ExternalSorter[K, V, C](
val w = writer
writer = null
w.commitAndClose()
- _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
- batchSizes.append(curWriteMetrics.shuffleBytesWritten)
+ _diskBytesSpilled += spillMetrics.shuffleBytesWritten
+ batchSizes.append(spillMetrics.shuffleBytesWritten)
+ spillMetrics = null
objectsWritten = 0
}
@@ -307,9 +287,7 @@ private[spark] class ExternalSorter[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
- curWriteMetrics = new ShuffleWriteMetrics()
- writer = blockManager.getDiskWriter(
- blockId, file, serInstance, fileBufferSize, curWriteMetrics)
+ openWriter()
}
}
if (objectsWritten > 0) {
@@ -337,46 +315,6 @@ private[spark] class ExternalSorter[K, V, C](
}
/**
- * Spill our in-memory collection to separate files, one for each partition. This is used when
- * there's no aggregator and ordering and the number of partitions is small, because it allows
- * writePartitionedFile to just concatenate files without deserializing data.
- *
- * @param collection whichever collection we're using (map or buffer)
- */
- private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = {
- spillToPartitionFiles(collection.writablePartitionedIterator())
- }
-
- private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = {
- assert(bypassMergeSort)
-
- // Create our file writers if we haven't done so yet
- if (partitionWriters == null) {
- curWriteMetrics = new ShuffleWriteMetrics()
- val openStartTime = System.nanoTime
- partitionWriters = Array.fill(numPartitions) {
- // Because these files may be read during shuffle, their compression must be controlled by
- // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
- // createTempShuffleBlock here; see SPARK-3426 for more context.
- val (blockId, file) = diskBlockManager.createTempShuffleBlock()
- val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize,
- curWriteMetrics)
- writer.open()
- }
- // Creating the file to write to and creating a disk writer both involve interacting with
- // the disk, and can take a long time in aggregate when we open many files, so should be
- // included in the shuffle write time.
- curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
- }
-
- // No need to sort stuff, just write each element out
- while (iterator.hasNext) {
- val partitionId = iterator.nextPartition()
- iterator.writeNext(partitionWriters(partitionId))
- }
- }
-
- /**
* Merge a sequence of sorted files, giving an iterator over partitions and then over elements
* inside each partition. This can be used to either write out a new file or return data to
* the user.
@@ -665,8 +603,6 @@ private[spark] class ExternalSorter[K, V, C](
}
/**
- * Exposed for testing purposes.
- *
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
* contents, and these are expected to be accessed in order (you can't "skip ahead" to one
@@ -676,10 +612,11 @@ private[spark] class ExternalSorter[K, V, C](
* For now, we just merge all the spilled files in one pass, but this can be modified to
* support hierarchical merging.
*/
- def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+ @VisibleForTesting
+ def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
- if (spills.isEmpty && partitionWriters == null) {
+ if (spills.isEmpty) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
@@ -689,13 +626,6 @@ private[spark] class ExternalSorter[K, V, C](
// We do need to sort by both partition ID and key
groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
}
- } else if (bypassMergeSort) {
- // Read data from each partition file and merge it together with the data in memory;
- // note that there's no ordering or aggregator in this case -- we just partition objects
- val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None))
- collIter.map { case (partitionId, values) =>
- (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
- }
} else {
// Merge spilled and in-memory data
merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
@@ -709,14 +639,13 @@ private[spark] class ExternalSorter[K, V, C](
/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
- * called by the SortShuffleWriter and can go through an efficient path of just concatenating
- * binary files if we decided to avoid merge-sorting.
+ * called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
- def writePartitionedFile(
+ override def writePartitionedFile(
blockId: BlockId,
context: TaskContext,
outputFile: File): Array[Long] = {
@@ -724,28 +653,7 @@ private[spark] class ExternalSorter[K, V, C](
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
- if (bypassMergeSort && partitionWriters != null) {
- // We decided to write separate files for each partition, so just concatenate them. To keep
- // this simple we spill out the current in-memory collection so that everything is in files.
- spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
- partitionWriters.foreach(_.commitAndClose())
- val out = new FileOutputStream(outputFile, true)
- val writeStartTime = System.nanoTime
- util.Utils.tryWithSafeFinally {
- for (i <- 0 until numPartitions) {
- val in = new FileInputStream(partitionWriters(i).fileSegment().file)
- util.Utils.tryWithSafeFinally {
- lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled)
- } {
- in.close()
- }
- }
- } {
- out.close()
- context.taskMetrics.shuffleWriteMetrics.foreach(
- _.incShuffleWriteTime(System.nanoTime - writeStartTime))
- }
- } else if (spills.isEmpty && partitionWriters == null) {
+ if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
@@ -761,7 +669,7 @@ private[spark] class ExternalSorter[K, V, C](
lengths(partitionId) = segment.length
}
} else {
- // Not bypassing merge-sort; get an iterator by partition and just write everything directly.
+ // We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
@@ -778,41 +686,15 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
- context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
- if (curWriteMetrics != null) {
- m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten)
- m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime)
- m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten)
- }
- }
lengths
}
- /**
- * Read a partition file back as an iterator (used in our iterator method)
- */
- private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
- if (writer.isOpen) {
- writer.commitAndClose()
- }
- new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get)
- }
-
def stop(): Unit = {
spills.foreach(s => s.file.delete())
spills.clear()
- if (partitionWriters != null) {
- partitionWriters.foreach { w =>
- w.revertPartialWritesAndClose()
- diskBlockManager.getFile(w.blockId).delete()
- }
- partitionWriters = null
- }
}
- def diskBytesSpilled: Long = _diskBytesSpilled
-
/**
* Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
* group together the pairs for each partition into a sub-iterator.
@@ -826,14 +708,6 @@ private[spark] class ExternalSorter[K, V, C](
(0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
}
- private def comparator: Option[Comparator[K]] = {
- if (ordering.isDefined || aggregator.isDefined) {
- Some(keyComparator)
- } else {
- None
- }
- }
-
/**
* An iterator that reads only the elements for a given partition ID from an underlying buffered
* stream, assuming this partition is the next one to be read. Used to make it easier to return
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
deleted file mode 100644
index d75959f480..0000000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
- def hasNext: Boolean = iter.hasNext
-
- def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
index e2e2f1faae..d0d25b43d0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V]
destructiveSortedIterator(comparator)
}
- def writablePartitionedIterator(): WritablePartitionedIterator = {
- WritablePartitionedIterator.fromIterator(super.iterator)
- }
-
def insert(partition: Int, key: K, value: V): Unit = {
update((partition, key), value)
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index e8332e1a87..5a6e9a9580 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
iterator
}
- override def writablePartitionedIterator(): WritablePartitionedIterator = {
- WritablePartitionedIterator.fromIterator(iterator)
- }
-
private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
var pos = 0
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index 554d88206e..862408b7a4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -122,10 +122,6 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
sort(keyComparator)
- writablePartitionedIterator
- }
-
- override def writablePartitionedIterator(): WritablePartitionedIterator = {
new WritablePartitionedIterator {
// current position in the meta buffer in ints
var pos = 0
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index f26d1618c9..7bc5989865 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
*/
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
- WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
- }
+ val it = partitionedDestructiveSortedIterator(keyComparator)
+ new WritablePartitionedIterator {
+ private[this] var cur = if (it.hasNext) it.next() else null
- /**
- * Iterate through the data and write out the elements instead of returning them.
- */
- def writablePartitionedIterator(): WritablePartitionedIterator
+ def writeNext(writer: BlockObjectWriter): Unit = {
+ writer.write(cur._1._2, cur._2)
+ cur = if (it.hasNext) it.next() else null
+ }
+
+ def hasNext(): Boolean = cur != null
+
+ def nextPartition(): Int = cur._1._1
+ }
+ }
}
private[spark] object WritablePartitionedPairCollection {
@@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator {
def nextPartition(): Int
}
-
-private[spark] object WritablePartitionedIterator {
- def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
- new WritablePartitionedIterator {
- var cur = if (it.hasNext) it.next() else null
-
- def writeNext(writer: BlockObjectWriter): Unit = {
- writer.write(cur._1._2, cur._2)
- cur = if (it.hasNext) it.next() else null
- }
-
- def hasNext(): Boolean = cur != null
-
- def nextPartition(): Int = cur._1._1
- }
- }
-}
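
For reference, an editorial sketch of how the iterator built above is consumed, mirroring ExternalSorter.writePartitionedFile's in-memory branch: consecutive elements sharing a partition go through one disk writer, and the resulting segment length is recorded. `collection`, `comparator`, `blockManager`, `blockId`, `outputFile`, `serInstance`, `fileBufferSize`, `context`, and `lengths` are that method's locals and fields.

    val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
    while (it.hasNext) {
      val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
        context.taskMetrics.shuffleWriteMetrics.get)
      val partitionId = it.nextPartition()
      while (it.hasNext && it.nextPartition() == partitionId) {
        it.writeNext(writer)  // writes the key/value pair for the current element
      }
      writer.commitAndClose()
      lengths(partitionId) = writer.fileSegment().length
    }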
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 91f4ab3608..c3c2b1ffc1 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.Matchers
import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
import org.apache.spark.util.MutablePair
@@ -281,6 +282,39 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// This count should retry the execution of the previous stage and rerun shuffle.
rdd.count()
}
+
+ test("metrics for shuffle without aggregation") {
+ sc = new SparkContext("local", "test", conf.clone())
+ val numRecords = 10000
+
+ val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+ sc.parallelize(1 to numRecords, 4)
+ .map(key => (key, 1))
+ .groupByKey()
+ .collect()
+ }
+
+ assert(metrics.recordsRead === numRecords)
+ assert(metrics.recordsWritten === numRecords)
+ assert(metrics.bytesWritten === metrics.bytesRead)
+ assert(metrics.bytesWritten > 0)
+ }
+
+ test("metrics for shuffle with aggregation") {
+ sc = new SparkContext("local", "test", conf.clone())
+ val numRecords = 10000
+
+ val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+ sc.parallelize(1 to numRecords, 4)
+ .flatMap(key => Array.fill(100)((key, 1)))
+ .countByKey()
+ }
+
+ assert(metrics.recordsRead === numRecords)
+ assert(metrics.recordsWritten === numRecords)
+ assert(metrics.bytesWritten === metrics.bytesRead)
+ assert(metrics.bytesWritten > 0)
+ }
}
object ShuffleSuite {
@@ -294,4 +328,35 @@ object ShuffleSuite {
value - o.value
}
}
+
+ case class AggregatedShuffleMetrics(
+ recordsWritten: Long,
+ recordsRead: Long,
+ bytesWritten: Long,
+ bytesRead: Long)
+
+ def runAndReturnMetrics(sc: SparkContext)(job: => Unit): AggregatedShuffleMetrics = {
+ @volatile var recordsWritten: Long = 0
+ @volatile var recordsRead: Long = 0
+ @volatile var bytesWritten: Long = 0
+ @volatile var bytesRead: Long = 0
+ val listener = new SparkListener {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m =>
+ recordsWritten += m.shuffleRecordsWritten
+ bytesWritten += m.shuffleBytesWritten
+ }
+ taskEnd.taskMetrics.shuffleReadMetrics.foreach { m =>
+ recordsRead += m.recordsRead
+ bytesRead += m.totalBytesRead
+ }
+ }
+ }
+ sc.addSparkListener(listener)
+
+ job
+
+ sc.listenerBus.waitUntilEmpty(500)
+ AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 19f1af0dcd..9e4d34fb7d 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -193,26 +193,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
assert(records == numRecords)
}
- test("shuffle records read metrics") {
- val recordsRead = runAndReturnShuffleRecordsRead {
- sc.textFile(tmpFilePath, 4)
- .map(key => (key, 1))
- .groupByKey()
- .collect()
- }
- assert(recordsRead == numRecords)
- }
-
- test("shuffle records written metrics") {
- val recordsWritten = runAndReturnShuffleRecordsWritten {
- sc.textFile(tmpFilePath, 4)
- .map(key => (key, 1))
- .groupByKey()
- .collect()
- }
- assert(recordsWritten == numRecords)
- }
-
/**
* Tests the metrics from end to end.
* 1) reading a hadoop file
@@ -301,14 +281,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten))
}
- private def runAndReturnShuffleRecordsRead(job: => Unit): Long = {
- runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead))
- }
-
- private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = {
- runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten))
- }
-
private def runAndReturnMetrics(job: => Unit,
collector: (SparkListenerTaskEnd) => Option[Long]): Long = {
val taskMetrics = new ArrayBuffer[Long]()
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
new file mode 100644
index 0000000000..c8420db612
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.File
+import java.util.UUID
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+import org.apache.spark._
+import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+
+class BypassMergeSortShuffleWriterSuite extends FunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+
+ private var taskMetrics: TaskMetrics = _
+ private var shuffleWriteMetrics: ShuffleWriteMetrics = _
+ private var tempDir: File = _
+ private var outputFile: File = _
+ private val conf: SparkConf = new SparkConf(loadDefaults = false)
+ private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
+ private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
+ private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
+ private val serializer: Serializer = new JavaSerializer(conf)
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ outputFile = File.createTempFile("shuffle", null, tempDir)
+ shuffleWriteMetrics = new ShuffleWriteMetrics
+ taskMetrics = new TaskMetrics
+ taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
+ MockitoAnnotations.initMocks(this)
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+ when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+ when(blockManager.getDiskWriter(
+ any[BlockId],
+ any[File],
+ any[SerializerInstance],
+ anyInt(),
+ any[ShuffleWriteMetrics]
+ )).thenAnswer(new Answer[BlockObjectWriter] {
+ override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+ val args = invocation.getArguments
+ new DiskBlockObjectWriter(
+ args(0).asInstanceOf[BlockId],
+ args(1).asInstanceOf[File],
+ args(2).asInstanceOf[SerializerInstance],
+ args(3).asInstanceOf[Int],
+ compressStream = identity,
+ syncWrites = false,
+ args(4).asInstanceOf[ShuffleWriteMetrics]
+ )
+ }
+ })
+ when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+ new Answer[(TempShuffleBlockId, File)] {
+ override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = {
+ val blockId = new TempShuffleBlockId(UUID.randomUUID)
+ val file = File.createTempFile(blockId.toString, null, tempDir)
+ blockIdToFileMap.put(blockId, file)
+ temporaryFilesCreated.append(file)
+ (blockId, file)
+ }
+ })
+ when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+ new Answer[File] {
+ override def answer(invocation: InvocationOnMock): File = {
+ blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get
+ }
+ })
+ }
+
+ override def afterEach(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ blockIdToFileMap.clear()
+ temporaryFilesCreated.clear()
+ }
+
+ test("write empty iterator") {
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ new SparkConf(loadDefaults = false),
+ blockManager,
+ new HashPartitioner(7),
+ shuffleWriteMetrics,
+ serializer
+ )
+ writer.insertAll(Iterator.empty)
+ val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+ assert(partitionLengths.sum === 0)
+ assert(outputFile.exists())
+ assert(outputFile.length() === 0)
+ assert(temporaryFilesCreated.isEmpty)
+ assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
+ assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ test("write with some empty partitions") {
+ def records: Iterator[(Int, Int)] =
+ Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ new SparkConf(loadDefaults = false),
+ blockManager,
+ new HashPartitioner(7),
+ shuffleWriteMetrics,
+ serializer
+ )
+ writer.insertAll(records)
+ assert(temporaryFilesCreated.nonEmpty)
+ val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+ assert(partitionLengths.sum === outputFile.length())
+ assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+ assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
+ assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ test("cleanup of intermediate files after errors") {
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ new SparkConf(loadDefaults = false),
+ blockManager,
+ new HashPartitioner(7),
+ shuffleWriteMetrics,
+ serializer
+ )
+ intercept[SparkException] {
+ writer.insertAll((0 until 100000).iterator.map(i => {
+ if (i == 99990) {
+ throw new SparkException("Intentional failure")
+ }
+ (i, i)
+ }))
+ }
+ assert(temporaryFilesCreated.nonEmpty)
+ writer.stop()
+ assert(temporaryFilesCreated.count(_.exists()) === 0)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
new file mode 100644
index 0000000000..c6ada7139c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.mockito.Mockito._
+import org.scalatest.FunSuite
+
+import org.apache.spark.{Aggregator, SparkConf}
+
+class SortShuffleWriterSuite extends FunSuite {
+
+ import SortShuffleWriter._
+
+ test("conditions for bypassing merge-sort") {
+ val conf = new SparkConf(loadDefaults = false)
+ val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
+ val ord = implicitly[Ordering[Int]]
+
+ // Numbers of partitions that are above and below the default bypassMergeThreshold
+ val FEW_PARTITIONS = 50
+ val MANY_PARTITIONS = 10000
+
+ // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
+ assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
+ assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
+
+ // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
+ assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
+ assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index ad43a3e5fd..7bdea724fe 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -18,14 +18,28 @@ package org.apache.spark.storage
import java.io.File
+import org.scalatest.BeforeAndAfterEach
+
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-class BlockObjectWriterSuite extends SparkFunSuite {
+class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ var tempDir: File = _
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ }
+
+ override def afterEach(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ }
+
test("verify write metrics") {
- val file = new File(Utils.createTempDir(), "somefile")
+ val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -47,7 +61,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
}
test("verify write metrics on revert") {
- val file = new File(Utils.createTempDir(), "somefile")
+ val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -70,7 +84,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
}
test("Reopening a closed block writer") {
- val file = new File(Utils.createTempDir(), "somefile")
+ val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -81,4 +95,79 @@ class BlockObjectWriterSuite extends SparkFunSuite {
writer.open()
}
}
+
+ test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ writer.commitAndClose()
+ val bytesWritten = writeMetrics.shuffleBytesWritten
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ writer.revertPartialWritesAndClose()
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+ }
+
+ test("commitAndClose() should be idempotent") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ writer.commitAndClose()
+ val bytesWritten = writeMetrics.shuffleBytesWritten
+ val writeTime = writeMetrics.shuffleWriteTime
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ writer.commitAndClose()
+ assert(writeMetrics.shuffleRecordsWritten === 1000)
+ assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+ assert(writeMetrics.shuffleWriteTime === writeTime)
+ }
+
+ test("revertPartialWritesAndClose() should be idempotent") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ writer.revertPartialWritesAndClose()
+ val bytesWritten = writeMetrics.shuffleBytesWritten
+ val writeTime = writeMetrics.shuffleWriteTime
+ assert(writeMetrics.shuffleRecordsWritten === 0)
+ writer.revertPartialWritesAndClose()
+ assert(writeMetrics.shuffleRecordsWritten === 0)
+ assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+ assert(writeMetrics.shuffleWriteTime === writeTime)
+ }
+
+ test("fileSegment() can only be called after commitAndClose() has been called") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ for (i <- 1 to 1000) {
+ writer.write(i, i)
+ }
+ intercept[IllegalStateException] {
+ writer.fileSegment()
+ }
+ writer.close()
+ }
+
+ test("commitAndClose() without ever opening or writing") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+ new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+ writer.commitAndClose()
+ assert(writer.fileSegment().length === 0)
+ }
}
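
Taken together, the new assertions pin down the DiskBlockObjectWriter lifecycle: commitAndClose() and revertPartialWritesAndClose() are idempotent, and fileSegment() is only meaningful after a successful commit. A sketch of the caller protocol they imply (illustrative usage, reusing the same constructor arguments as the tests above):

  val file = new File(tempDir, "sketchfile")
  val writeMetrics = new ShuffleWriteMetrics()
  val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
    new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
  try {
    (1 to 1000).foreach { i => writer.write(i, i) }
    writer.commitAndClose()
    val segment = writer.fileSegment()      // only valid after commitAndClose()
  } catch {
    case t: Throwable =>
      writer.revertPartialWritesAndClose()  // discards partial output; safe to call again
      throw t
  }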
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 9039dbef1f..7d7b41bc23 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -23,10 +23,12 @@ import org.scalatest.PrivateMethodTester
import scala.util.Random
+import org.scalatest.FunSuite
+
import org.apache.spark._
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester {
+class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
val conf = new SparkConf(loadDefaults)
if (kryo) {
@@ -37,21 +39,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
conf.set("spark.serializer.objectStreamReset", "1")
conf.set("spark.serializer", classOf[JavaSerializer].getName)
}
+ conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
// Ensure that we actually have multiple batches per spill file
conf.set("spark.shuffle.spill.batchSize", "10")
conf
}
- private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
- val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
- assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort")
- }
-
- private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
- val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
- assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort")
- }
-
test("empty data stream with kryo ser") {
emptyDataStream(createSparkConf(false, true))
}
@@ -161,39 +154,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(7)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter)
- sorter.insertAll(elements)
- assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
- val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
- assert(iter.next() === (0, Nil))
- assert(iter.next() === (1, List((1, 1))))
- assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
- assert(iter.next() === (3, Nil))
- assert(iter.next() === (4, Nil))
- assert(iter.next() === (5, List((5, 5))))
- assert(iter.next() === (6, Nil))
- sorter.stop()
- }
-
- test("empty partitions with spilling, bypass merge-sort with kryo ser") {
- emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, true))
- }
-
- test("empty partitions with spilling, bypass merge-sort with java ser") {
- emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, false))
- }
-
- def emptyPartitionerWithSpillingBypassMergeSort(conf: SparkConf) {
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
-
- val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
-
- val sorter = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(7)), None, None)
- assertBypassedMergeSort(sorter)
sorter.insertAll(elements)
assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -376,7 +336,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(3)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter)
sorter.insertAll((0 until 120000).iterator.map(i => (i, i)))
assert(diskBlockManager.getAllFiles().length > 0)
sorter.stop()
@@ -384,7 +343,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter2 = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(3)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter2)
sorter2.insertAll((0 until 120000).iterator.map(i => (i, i)))
assert(diskBlockManager.getAllFiles().length > 0)
assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet)
@@ -392,29 +350,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
assert(diskBlockManager.getAllBlocks().length === 0)
}
- test("cleanup of intermediate files in sorter, bypass merge-sort") {
- val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
- val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
- val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
- assertBypassedMergeSort(sorter)
- sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
- assert(diskBlockManager.getAllFiles().length > 0)
- sorter.stop()
- assert(diskBlockManager.getAllBlocks().length === 0)
-
- val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
- assertBypassedMergeSort(sorter2)
- sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
- assert(diskBlockManager.getAllFiles().length > 0)
- assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
- sorter2.stop()
- assert(diskBlockManager.getAllBlocks().length === 0)
- }
-
test("cleanup of intermediate files in sorter if there are errors") {
val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -426,7 +361,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
val sorter = new ExternalSorter[Int, Int, Int](
None, Some(new HashPartitioner(3)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter)
intercept[SparkException] {
sorter.insertAll((0 until 120000).iterator.map(i => {
if (i == 119990) {
@@ -440,28 +374,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
assert(diskBlockManager.getAllBlocks().length === 0)
}
- test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") {
- val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
- val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
- val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
- assertBypassedMergeSort(sorter)
- intercept[SparkException] {
- sorter.insertAll((0 until 100000).iterator.map(i => {
- if (i == 99990) {
- throw new SparkException("Intentional failure")
- }
- (i, i)
- }))
- }
- assert(diskBlockManager.getAllFiles().length > 0)
- sorter.stop()
- assert(diskBlockManager.getAllBlocks().length === 0)
- }
-
test("cleanup of intermediate files in shuffle") {
val conf = createSparkConf(false, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -776,40 +688,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
}
}
- test("conditions for bypassing merge-sort") {
- val conf = createSparkConf(false, false)
- conf.set("spark.shuffle.memoryFraction", "0.001")
- conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local", "test", conf)
-
- val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
- val ord = implicitly[Ordering[Int]]
-
- // Numbers of partitions that are above and below the default bypassMergeThreshold
- val FEW_PARTITIONS = 50
- val MANY_PARTITIONS = 10000
-
- // Sorters with no ordering or aggregator: should bypass unless # of partitions is high
-
- val sorter1 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
- assertBypassedMergeSort(sorter1)
-
- val sorter2 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None)
- assertDidNotBypassMergeSort(sorter2)
-
- // Sorters with an ordering or aggregator: should not bypass even if they have few partitions
-
- val sorter3 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None)
- assertDidNotBypassMergeSort(sorter3)
-
- val sorter4 = new ExternalSorter[Int, Int, Int](
- Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
- assertDidNotBypassMergeSort(sorter4)
- }
-
test("sort without breaking sorting contracts with kryo ser") {
sortWithoutBreakingSortingContracts(createSparkConf(true, true))
}