author     Brian Cho <bcho@fb.com>                  2016-07-24 19:36:58 -0700
committer  Josh Rosen <joshrosen@databricks.com>    2016-07-24 19:36:58 -0700
commit     daace6014216b996bcc8937f1fdcea732b6910ca (patch)
tree       ae328f4d9fb1e11cc0034ed26665082dd92507a8 /core
parent     1221ce04029154778ccb5453e348f6d116092cc5 (diff)
[SPARK-5581][CORE] When writing sorted map output file, avoid open / close between each partition

## What changes were proposed in this pull request?

Replace commitAndClose with separate commit and close to avoid opening and closing the file between partitions.

## How was this patch tested?

Run existing unit tests, and add a few unit tests covering reverts. Observed a ~20% reduction in total time in tasks on stages with shuffle writes to many partitions.

JoshRosen

Author: Brian Cho <bcho@fb.com>

Closes #13382 from dafrista/separatecommit-master.
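The summary above describes an API split: commitAndClose() becomes commitAndGet(), which returns a FileSegment for the bytes committed so far while keeping the file channel open, plus an explicit close(). A minimal sketch of the resulting calling pattern follows; the constructor call mirrors DiskBlockObjectWriterSuite further down in this diff, while the package declaration, output path, partition count, and recordsFor helper are hypothetical (DiskBlockObjectWriter is private[spark], so the sketch assumes it compiles inside the Spark source tree).

```scala
package org.apache.spark.storage

import java.io.File

import org.apache.spark.SparkConf
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer

// Sketch only: one writer, one commitAndGet() per partition, one close() per task.
object SinglePassPartitionWriteSketch {
  // Hypothetical per-partition record source.
  private def recordsFor(pid: Int): Seq[(Int, Int)] = Seq((pid, pid * 10))

  def main(args: Array[String]): Unit = {
    val writeMetrics = new ShuffleWriteMetrics()
    val writer = new DiskBlockObjectWriter(
      new File("/tmp/sorted-map-output"),
      new JavaSerializer(new SparkConf()).newInstance(),
      1024, os => os, true, writeMetrics)

    val numPartitions = 4
    val partitionLengths = new Array[Long](numPartitions)
    for (pid <- 0 until numPartitions) {
      recordsFor(pid).foreach { case (k, v) => writer.write(k, v) }
      // Commit this partition's bytes; the underlying file channel stays open.
      val segment = writer.commitAndGet()
      partitionLengths(pid) = segment.length
    }
    writer.close() // the file is opened and closed once, not once per partition
  }
}
```

Previously the sort-based write paths constructed a fresh writer against the same file at each partition boundary, paying an open/close cycle per partition; that is the overhead behind the reported ~20% task-time reduction.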
Diffstat (limited to 'core')
 -rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java            10
 -rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java                   31
 -rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java   3
 -rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala                     157
 -rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala              28
 -rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala                     52
 -rw-r--r--  core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala                 67
 7 files changed, 192 insertions, 156 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 0e9defe5b4..83dc61c5e5 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -88,6 +88,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
+ private FileSegment[] partitionWriterSegments;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;
@@ -131,6 +132,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
partitionWriters = new DiskBlockObjectWriter[numPartitions];
+ partitionWriterSegments = new FileSegment[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
@@ -150,8 +152,10 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- for (DiskBlockObjectWriter writer : partitionWriters) {
- writer.commitAndClose();
+ for (int i = 0; i < numPartitions; i++) {
+ final DiskBlockObjectWriter writer = partitionWriters[i];
+ partitionWriterSegments[i] = writer.commitAndGet();
+ writer.close();
}
File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
@@ -184,7 +188,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
boolean threwException = true;
try {
for (int i = 0; i < numPartitions; i++) {
- final File file = partitionWriters[i].fileSegment().file();
+ final File file = partitionWriterSegments[i].file();
if (file.exists()) {
final FileInputStream in = new FileInputStream(file);
boolean copyThrewException = true;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index cf38a04ed7..cfec724fe9 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -37,6 +37,7 @@ import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.FileSegment;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
@@ -150,10 +151,6 @@ final class ShuffleExternalSorter extends MemoryConsumer {
final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();
- // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
- // after SPARK-5581 is fixed.
- DiskBlockObjectWriter writer;
-
// Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
// be an API to directly transfer bytes from managed memory to the disk writer, we buffer
// data through a byte array. This array does not need to be large enough to hold a single
@@ -175,7 +172,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
// around this, we pass a dummy no-op serializer.
final SerializerInstance ser = DummySerializerInstance.INSTANCE;
- writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+ final DiskBlockObjectWriter writer =
+ blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
int currentPartition = -1;
while (sortedRecords.hasNext()) {
@@ -185,12 +183,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
if (partition != currentPartition) {
// Switch to the new partition
if (currentPartition != -1) {
- writer.commitAndClose();
- spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ final FileSegment fileSegment = writer.commitAndGet();
+ spillInfo.partitionLengths[currentPartition] = fileSegment.length();
}
currentPartition = partition;
- writer =
- blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
}
final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
@@ -209,15 +205,14 @@ final class ShuffleExternalSorter extends MemoryConsumer {
writer.recordWritten();
}
- if (writer != null) {
- writer.commitAndClose();
- // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
- // then the file might be empty. Note that it might be better to avoid calling
- // writeSortedFile() in that case.
- if (currentPartition != -1) {
- spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
- spills.add(spillInfo);
- }
+ final FileSegment committedSegment = writer.commitAndGet();
+ writer.close();
+ // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+ // then the file might be empty. Note that it might be better to avoid calling
+ // writeSortedFile() in that case.
+ if (currentPartition != -1) {
+ spillInfo.partitionLengths[currentPartition] = committedSegment.length();
+ spills.add(spillInfo);
}
if (!isLastFile) { // i.e. this is a spill file
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 9ba760e842..164b9d70b7 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -136,7 +136,8 @@ public final class UnsafeSorterSpillWriter {
}
public void close() throws IOException {
- writer.commitAndClose();
+ writer.commitAndGet();
+ writer.close();
writer = null;
writeBuffer = null;
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 5b493f470b..e5b1bf2f4b 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -27,8 +27,10 @@ import org.apache.spark.util.Utils
/**
* A class for writing JVM objects directly to a file on disk. This class allows data to be appended
- * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
- * revert partial writes.
+ * to an existing block. For efficiency, it retains the underlying file channel across
+ * multiple commits. This channel is kept open until close() is called. In case of faults,
+ * callers should instead close with revertPartialWritesAndClose() to atomically revert the
+ * uncommitted partial writes.
*
* This class does not support concurrent writes. Also, once the writer has been opened it cannot be
* reopened again.
@@ -46,34 +48,49 @@ private[spark] class DiskBlockObjectWriter(
extends OutputStream
with Logging {
+ /**
+ * Guards against close calls, e.g. from a wrapping stream.
+ * Call manualClose to close the stream that was extended by this trait.
+ * Commit uses this trait to close object streams without paying the
+ * cost of closing and opening the underlying file.
+ */
+ private trait ManualCloseOutputStream extends OutputStream {
+ abstract override def close(): Unit = {
+ flush()
+ }
+
+ def manualClose(): Unit = {
+ super.close()
+ }
+ }
+
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
+ private var mcs: ManualCloseOutputStream = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var initialized = false
+ private var streamOpen = false
private var hasBeenClosed = false
- private var commitAndCloseHasBeenCalled = false
/**
* Cursors used to represent positions in the file.
*
- * xxxxxxxx|--------|---      |
- *         ^        ^         ^
- *         |        |       finalPosition
- *         |      reportedPosition
- *       initialPosition
+ * xxxxxxxxxx|----------|-----|
+ *           ^          ^     ^
+ *           |          |    channel.position()
+ *           |        reportedPosition
+ *         committedPosition
*
- * initialPosition: Offset in the file where we start writing. Immutable.
* reportedPosition: Position at the time of the last update to the write metrics.
- * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed.
+ * committedPosition: Offset after last committed write.
* -----: Current writes to the underlying file.
- * xxxxx: Existing contents of the file.
+ * xxxxx: Committed contents of the file.
*/
- private val initialPosition = file.length()
- private var finalPosition: Long = -1
- private var reportedPosition = initialPosition
+ private var committedPosition = file.length()
+ private var reportedPosition = committedPosition
/**
* Keep track of number of records written and also use this to periodically
@@ -81,67 +98,98 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsWritten = 0
+ private def initialize(): Unit = {
+ fos = new FileOutputStream(file, true)
+ channel = fos.getChannel()
+ ts = new TimeTrackingOutputStream(writeMetrics, fos)
+ class ManualCloseBufferedOutputStream
+ extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
+ mcs = new ManualCloseBufferedOutputStream
+ }
+
def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
- fos = new FileOutputStream(file, true)
- ts = new TimeTrackingOutputStream(writeMetrics, fos)
- channel = fos.getChannel()
- bs = compressStream(new BufferedOutputStream(ts, bufferSize))
+ if (!initialized) {
+ initialize()
+ initialized = true
+ }
+ bs = compressStream(mcs)
objOut = serializerInstance.serializeStream(bs)
- initialized = true
+ streamOpen = true
this
}
- override def close() {
+ /**
+ * Close and cleanup all resources.
+ * Should call after committing or reverting partial writes.
+ */
+ private def closeResources(): Unit = {
if (initialized) {
- Utils.tryWithSafeFinally {
- if (syncWrites) {
- // Force outstanding writes to disk and track how long it takes
- objOut.flush()
- val start = System.nanoTime()
- fos.getFD.sync()
- writeMetrics.incWriteTime(System.nanoTime() - start)
- }
- } {
- objOut.close()
- }
-
+ mcs.manualClose()
channel = null
+ mcs = null
bs = null
fos = null
ts = null
objOut = null
initialized = false
+ streamOpen = false
hasBeenClosed = true
}
}
- def isOpen: Boolean = objOut != null
+ /**
+ * Commits any remaining partial writes and closes resources.
+ */
+ override def close() {
+ if (initialized) {
+ Utils.tryWithSafeFinally {
+ commitAndGet()
+ } {
+ closeResources()
+ }
+ }
+ }
/**
* Flush the partial writes and commit them as a single atomic block.
+ * A commit may write additional bytes to frame the atomic block.
+ *
+ * @return file segment with previous offset and length committed on this call.
*/
- def commitAndClose(): Unit = {
- if (initialized) {
+ def commitAndGet(): FileSegment = {
+ if (streamOpen) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
objOut.flush()
bs.flush()
- close()
- finalPosition = file.length()
- // In certain compression codecs, more bytes are written after close() is called
- writeMetrics.incBytesWritten(finalPosition - reportedPosition)
+ objOut.close()
+ streamOpen = false
+
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ writeMetrics.incWriteTime(System.nanoTime() - start)
+ }
+
+ val pos = channel.position()
+ val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition)
+ committedPosition = pos
+ // In certain compression codecs, more bytes are written after streams are closed
+ writeMetrics.incBytesWritten(committedPosition - reportedPosition)
+ reportedPosition = committedPosition
+ fileSegment
} else {
- finalPosition = file.length()
+ new FileSegment(file, committedPosition, 0)
}
- commitAndCloseHasBeenCalled = true
}
/**
- * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * Reverts writes that haven't been committed yet. Callers should invoke this function
* when there are runtime exceptions. This method will not throw, though it may be
* unsuccessful in truncating written data.
*
@@ -152,16 +200,15 @@ private[spark] class DiskBlockObjectWriter(
// truncating the file to its initial position.
try {
if (initialized) {
- writeMetrics.decBytesWritten(reportedPosition - initialPosition)
+ writeMetrics.decBytesWritten(reportedPosition - committedPosition)
writeMetrics.decRecordsWritten(numRecordsWritten)
- objOut.flush()
- bs.flush()
- close()
+ streamOpen = false
+ closeResources()
}
val truncateStream = new FileOutputStream(file, true)
try {
- truncateStream.getChannel.truncate(initialPosition)
+ truncateStream.getChannel.truncate(committedPosition)
file
} finally {
truncateStream.close()
@@ -177,7 +224,7 @@ private[spark] class DiskBlockObjectWriter(
* Writes a key-value pair.
*/
def write(key: Any, value: Any) {
- if (!initialized) {
+ if (!streamOpen) {
open()
}
@@ -189,7 +236,7 @@ private[spark] class DiskBlockObjectWriter(
override def write(b: Int): Unit = throw new UnsupportedOperationException()
override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
- if (!initialized) {
+ if (!streamOpen) {
open()
}
@@ -209,18 +256,6 @@ private[spark] class DiskBlockObjectWriter(
}
/**
- * Returns the file segment of committed data that this Writer has written.
- * This is only valid after commitAndClose() has been called.
- */
- def fileSegment(): FileSegment = {
- if (!commitAndCloseHasBeenCalled) {
- throw new IllegalStateException(
- "fileSegment() is only valid after commitAndClose() has been called")
- }
- new FileSegment(file, initialPosition, finalPosition - initialPosition)
- }
-
- /**
* Report the number of bytes written in this writer's shuffle write metrics.
* Note that this is only valid before the underlying streams are closed.
*/
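The ManualCloseOutputStream trait added in this file is what lets commitAndGet() close the serializer and compression streams at every commit without giving up the underlying file. A standalone sketch of that stackable-trait trick, using only java.io (the ByteArrayOutputStream is a hypothetical stand-in for the real FileOutputStream, and all names below are illustrative):

```scala
import java.io.{BufferedOutputStream, ByteArrayOutputStream, OutputStream}

// close() on anything stacked above this trait only flushes; the stream is
// actually released when manualClose() is called.
trait ManualCloseOutputStream extends OutputStream {
  abstract override def close(): Unit = flush()
  def manualClose(): Unit = super.close()
}

object ManualCloseSketch {
  def main(args: Array[String]): Unit = {
    val underlying = new ByteArrayOutputStream()
    val mcs = new BufferedOutputStream(underlying, 1024) with ManualCloseOutputStream

    mcs.write("commit 1 ".getBytes("UTF-8"))
    mcs.close()                           // flush only; the stream remains writable
    mcs.write("commit 2".getBytes("UTF-8"))
    mcs.manualClose()                     // now the stream is really closed
    println(underlying.toString("UTF-8")) // "commit 1 commit 2"
  }
}
```

In DiskBlockObjectWriter the same pattern sits between the compression/serialization streams and the FileOutputStream, so objOut.close() inside commitAndGet() flushes each commit while only closeResources() calls manualClose().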
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 6ddc72afde..8c8860bb37 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -105,8 +105,8 @@ class ExternalAppendOnlyMap[K, V, C](
private val fileBufferSize =
sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
- // Write metrics for current spill
- private var curWriteMetrics: ShuffleWriteMetrics = _
+ // Write metrics
+ private val writeMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics()
// Peak size of the in-memory map observed so far, in bytes
private var _peakMemoryUsedBytes: Long = 0L
@@ -206,8 +206,7 @@ class ExternalAppendOnlyMap[K, V, C](
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)])
: DiskMapIterator = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
- curWriteMetrics = new ShuffleWriteMetrics()
- var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
+ val writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics)
var objectsWritten = 0
// List of batch sizes (bytes) in the order they are written to disk
@@ -215,11 +214,9 @@ class ExternalAppendOnlyMap[K, V, C](
// Flush the disk writer's contents to disk, and update relevant variables
def flush(): Unit = {
- val w = writer
- writer = null
- w.commitAndClose()
- _diskBytesSpilled += curWriteMetrics.bytesWritten
- batchSizes.append(curWriteMetrics.bytesWritten)
+ val segment = writer.commitAndGet()
+ batchSizes.append(segment.length)
+ _diskBytesSpilled += segment.length
objectsWritten = 0
}
@@ -232,25 +229,20 @@ class ExternalAppendOnlyMap[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
- curWriteMetrics = new ShuffleWriteMetrics()
- writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
}
}
if (objectsWritten > 0) {
flush()
- } else if (writer != null) {
- val w = writer
- writer = null
- w.revertPartialWritesAndClose()
+ writer.close()
+ } else {
+ writer.revertPartialWritesAndClose()
}
success = true
} finally {
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
- if (writer != null) {
- writer.revertPartialWritesAndClose()
- }
+ writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 4067acee73..708a0070e2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -272,14 +272,9 @@ private[spark] class ExternalSorter[K, V, C](
// These variables are reset after each flush
var objectsWritten: Long = 0
- var spillMetrics: ShuffleWriteMetrics = null
- var writer: DiskBlockObjectWriter = null
- def openWriter(): Unit = {
- assert (writer == null && spillMetrics == null)
- spillMetrics = new ShuffleWriteMetrics
- writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
- }
- openWriter()
+ val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
+ val writer: DiskBlockObjectWriter =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]
@@ -288,14 +283,11 @@ private[spark] class ExternalSorter[K, V, C](
val elementsPerPartition = new Array[Long](numPartitions)
// Flush the disk writer's contents to disk, and update relevant variables.
- // The writer is closed at the end of this process, and cannot be reused.
+ // The writer is committed at the end of this process.
def flush(): Unit = {
- val w = writer
- writer = null
- w.commitAndClose()
- _diskBytesSpilled += spillMetrics.bytesWritten
- batchSizes.append(spillMetrics.bytesWritten)
- spillMetrics = null
+ val segment = writer.commitAndGet()
+ batchSizes.append(segment.length)
+ _diskBytesSpilled += segment.length
objectsWritten = 0
}
@@ -311,24 +303,21 @@ private[spark] class ExternalSorter[K, V, C](
if (objectsWritten == serializerBatchSize) {
flush()
- openWriter()
}
}
if (objectsWritten > 0) {
flush()
- } else if (writer != null) {
- val w = writer
- writer = null
- w.revertPartialWritesAndClose()
+ } else {
+ writer.revertPartialWritesAndClose()
}
success = true
} finally {
- if (!success) {
+ if (success) {
+ writer.close()
+ } else {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
- if (writer != null) {
- writer.revertPartialWritesAndClose()
- }
+ writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
@@ -693,42 +682,37 @@ private[spark] class ExternalSorter[K, V, C](
blockId: BlockId,
outputFile: File): Array[Long] = {
- val writeMetrics = context.taskMetrics().shuffleWriteMetrics
-
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
+ val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
+ context.taskMetrics().shuffleWriteMetrics)
if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
- val writer = blockManager.getDiskWriter(
- blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
val partitionId = it.nextPartition()
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
- writer.commitAndClose()
- val segment = writer.fileSegment()
+ val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
- val writer = blockManager.getDiskWriter(
- blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
- writer.commitAndClose()
- val segment = writer.fileSegment()
+ val segment = writer.commitAndGet()
lengths(id) = segment.length
}
}
}
+ writer.close()
context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
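The spill paths in ExternalAppendOnlyMap and ExternalSorter above now share one lifecycle: commit after each successful batch, revert to the last commit on failure, close the writer exactly once. A hedged sketch of that contract (the object name and batches parameter are illustrative; the package declaration is only there because DiskBlockObjectWriter is private[spark]):

```scala
package org.apache.spark.util.collection

import org.apache.spark.storage.DiskBlockObjectWriter

// Illustrative only: the commit/revert/close lifecycle shared by the spill loops.
object SpillLifecycleSketch {
  def spillBatches(
      writer: DiskBlockObjectWriter,
      batches: Iterator[Seq[(Any, Any)]]): Long = {
    var diskBytesSpilled = 0L
    var success = false
    try {
      for (batch <- batches) {
        batch.foreach { case (k, v) => writer.write(k, v) }
        // Each commit extends the durable prefix of the file and returns a
        // FileSegment (offset + length) covering just this batch.
        val segment = writer.commitAndGet()
        diskBytesSpilled += segment.length
      }
      success = true
    } finally {
      if (success) {
        writer.close()                        // commits any tail, then closes
      } else {
        writer.revertPartialWritesAndClose()  // truncate back to the last commit
      }
    }
    diskBytesSpilled
  }
}
```

This mirrors the finally blocks above: on the success path the writer is closed after the last flush(), and on the failure path revertPartialWritesAndClose() truncates the file back to committedPosition so earlier batches stay intact.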
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index ec4ef4b2fc..059c2c2444 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -60,7 +60,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
}
assert(writeMetrics.bytesWritten > 0)
assert(writeMetrics.recordsWritten === 16385)
- writer.commitAndClose()
+ writer.commitAndGet()
+ writer.close()
assert(file.length() == writeMetrics.bytesWritten)
}
@@ -100,6 +101,40 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
}
}
+ test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(
+ file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+
+ writer.write(Long.box(20), Long.box(30))
+ val firstSegment = writer.commitAndGet()
+ assert(firstSegment.length === file.length())
+ assert(writeMetrics.shuffleBytesWritten === file.length())
+
+ writer.write(Long.box(40), Long.box(50))
+
+ writer.revertPartialWritesAndClose()
+ assert(firstSegment.length === file.length())
+ assert(writeMetrics.shuffleBytesWritten === file.length())
+ }
+
+ test("calling revertPartialWritesAndClose() after commit() should have no effect") {
+ val file = new File(tempDir, "somefile")
+ val writeMetrics = new ShuffleWriteMetrics()
+ val writer = new DiskBlockObjectWriter(
+ file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+
+ writer.write(Long.box(20), Long.box(30))
+ val firstSegment = writer.commitAndGet()
+ assert(firstSegment.length === file.length())
+ assert(writeMetrics.shuffleBytesWritten === file.length())
+
+ writer.revertPartialWritesAndClose()
+ assert(firstSegment.length === file.length())
+ assert(writeMetrics.shuffleBytesWritten === file.length())
+ }
+
test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
@@ -108,7 +143,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
for (i <- 1 to 1000) {
writer.write(i, i)
}
- writer.commitAndClose()
+ writer.commitAndGet()
+ writer.close()
val bytesWritten = writeMetrics.bytesWritten
assert(writeMetrics.recordsWritten === 1000)
writer.revertPartialWritesAndClose()
@@ -116,7 +152,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
assert(writeMetrics.bytesWritten === bytesWritten)
}
- test("commitAndClose() should be idempotent") {
+ test("commit() and close() should be idempotent") {
val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(
@@ -124,11 +160,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
for (i <- 1 to 1000) {
writer.write(i, i)
}
- writer.commitAndClose()
+ writer.commitAndGet()
+ writer.close()
val bytesWritten = writeMetrics.bytesWritten
val writeTime = writeMetrics.writeTime
assert(writeMetrics.recordsWritten === 1000)
- writer.commitAndClose()
+ writer.commitAndGet()
+ writer.close()
assert(writeMetrics.recordsWritten === 1000)
assert(writeMetrics.bytesWritten === bytesWritten)
assert(writeMetrics.writeTime === writeTime)
@@ -152,26 +190,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
assert(writeMetrics.writeTime === writeTime)
}
- test("fileSegment() can only be called after commitAndClose() has been called") {
+ test("commit() and close() without ever opening or writing") {
val file = new File(tempDir, "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(
file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
- for (i <- 1 to 1000) {
- writer.write(i, i)
- }
- intercept[IllegalStateException] {
- writer.fileSegment()
- }
+ val segment = writer.commitAndGet()
writer.close()
- }
-
- test("commitAndClose() without ever opening or writing") {
- val file = new File(tempDir, "somefile")
- val writeMetrics = new ShuffleWriteMetrics()
- val writer = new DiskBlockObjectWriter(
- file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
- writer.commitAndClose()
- assert(writer.fileSegment().length === 0)
+ assert(segment.length === 0)
}
}