 core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala             |   8
 core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala              |   2
 core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 179
 docs/configuration.md                                                            |   4
 4 files changed, 93 insertions(+), 100 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 530712b5df..696b930a26 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -66,6 +66,11 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
* Cumulative time spent performing blocking writes, in ns.
*/
def timeWriting(): Long
+
+ /**
+ * Number of bytes written so far
+ */
+ def bytesWritten: Long
}
/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
@@ -183,7 +188,8 @@ private[spark] class DiskBlockObjectWriter(
// Only valid if called after close()
override def timeWriting() = _timeWriting
- def bytesWritten: Long = {
+ // Only valid if called after commit()
+ override def bytesWritten: Long = {
lastValidPosition - initialPosition
}
}
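
[Reviewer note: promoting bytesWritten into the abstract BlockObjectWriter lets callers do per-batch accounting without casting to DiskBlockObjectWriter. A minimal usage sketch, reusing the getDiskWriter call from the spill path later in this patch; the `iterator` of values is assumed, not part of this hunk:

    val writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
    iterator.foreach(writer.write)
    writer.commit()                         // bytesWritten is only valid after commit()
    val batchBytes: Long = writer.bytesWritten
    writer.close()
]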
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index a8ef7fa8b6..f3e1c38744 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -50,7 +50,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
addShutdownHook()
/**
- * Returns the phyiscal file segment in which the given BlockId is located.
+ * Returns the physical file segment in which the given BlockId is located.
* If the BlockId has been mapped to a specific FileSegment, that will be returned.
* Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
*/
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 3d9b09ec33..7eb300d46e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -24,11 +24,11 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+import com.google.common.io.ByteStreams
import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.io.LZFCompressionCodec
-import org.apache.spark.serializer.{KryoDeserializationStream, Serializer}
-import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockObjectWriter}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, BlockManager}
/**
* An append-only map that spills sorted content to disk when there is insufficient space for it
@@ -84,12 +84,15 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
// Number of in-memory pairs inserted before tracking the map's shuffle memory usage
private val trackMemoryThreshold = 1000
- // Size of object batches when reading/writing from serializers. Objects are written in
- // batches, with each batch using its own serialization stream. This cuts down on the size
- // of reference-tracking maps constructed when deserializing a stream.
- //
- // NOTE: Setting this too low can cause excess copying when serializing, since some serializers
- // grow internal data structures by growing + copying every time the number of objects doubles.
+ /**
+ * Size of object batches when reading/writing from serializers.
+ *
+ * Objects are written in batches, with each batch using its own serialization stream. This
+ * cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+ *
+ * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
+ * grow internal data structures by growing + copying every time the number of objects doubles.
+ */
private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
// How many times we have spilled so far
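
[Reviewer note: the batch size is an ordinary SparkConf knob; a hypothetical tuning sketch, with the key and default taken from the hunk above:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // Larger batches mean fewer stream re-opens per spill file, but bigger
    // reference-tracking maps during deserialization; smaller batches risk
    // extra copying in serializers that grow internal buffers by doubling.
    conf.set("spark.shuffle.spill.batchSize", "5000")
]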
@@ -100,7 +103,6 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
private var _diskBytesSpilled = 0L
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
- private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
private val comparator = new KCComparator[K, C]
private val ser = serializer.newInstance()
@@ -153,37 +155,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
.format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
+ var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+ var objectsWritten = 0
- /* IMPORTANT NOTE: To avoid having to keep large object graphs in memory, this approach
- * closes and re-opens serialization and compression streams within each file. This makes some
- * assumptions about the way that serialization and compression streams work, specifically:
- *
- * 1) The serializer input streams do not pre-fetch data from the underlying stream.
- *
- * 2) Several compression streams can be opened, written to, and flushed on the write path
- * while only one compression input stream is created on the read path
- *
- * In practice (1) is only true for Java, so we add a special fix below to make it work for
- * Kryo. (2) is only true for LZF and not Snappy, so we coerce this to use LZF.
- *
- * To avoid making these assumptions we should create an intermediate stream that batches
- * objects and sends an EOF to the higher layer streams to make sure they never prefetch data.
- * This is a bit tricky because, within each segment, you'd need to track the total number
- * of bytes written and then re-wind and write it at the beginning of the segment. This will
- * most likely require using the file channel API.
- */
+ // List of batch sizes (bytes) in the order they are written to disk
+ val batchSizes = new ArrayBuffer[Long]
- val shouldCompress = blockManager.shouldCompress(blockId)
- val compressionCodec = new LZFCompressionCodec(sparkConf)
- def wrapForCompression(outputStream: OutputStream) = {
- if (shouldCompress) compressionCodec.compressedOutputStream(outputStream) else outputStream
+ // Flush the disk writer's contents to disk, and update relevant variables
+ def flush() = {
+ writer.commit()
+ val bytesWritten = writer.bytesWritten
+ batchSizes.append(bytesWritten)
+ _diskBytesSpilled += bytesWritten
+ objectsWritten = 0
}
- def getNewWriter = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize,
- wrapForCompression, syncWrites)
-
- var writer = getNewWriter
- var objectsWritten = 0
try {
val it = currentMap.destructiveSortedIterator(comparator)
while (it.hasNext) {
@@ -192,22 +178,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
objectsWritten += 1
if (objectsWritten == serializerBatchSize) {
- writer.commit()
+ flush()
writer.close()
- _diskBytesSpilled += writer.bytesWritten
- writer = getNewWriter
- objectsWritten = 0
+ writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
}
}
-
- if (objectsWritten > 0) writer.commit()
+ if (objectsWritten > 0) {
+ flush()
+ }
} finally {
// Partial failures cannot be tolerated; do not revert partial writes
writer.close()
- _diskBytesSpilled += writer.bytesWritten
}
+
currentMap = new SizeTrackingAppendOnlyMap[K, C]
- spilledMaps.append(new DiskMapIterator(file, blockId))
+ spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
// Reset the amount of shuffle memory used by this map in the global pool
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
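
[Reviewer note: taken together, the write side of the new scheme reduces to the batch loop below. This is a condensed sketch of the hunks above, assuming the file's existing imports and that `it` iterates the sorted in-memory pairs; it is not a drop-in replacement:

    val batchSizes = new ArrayBuffer[Long]
    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
    var objectsWritten = 0

    def flush(): Unit = {
      writer.commit()
      val bytesWritten = writer.bytesWritten
      batchSizes.append(bytesWritten)       // each batch's on-disk length, in write order
      _diskBytesSpilled += bytesWritten
      objectsWritten = 0
    }

    while (it.hasNext) {
      writer.write(it.next())
      objectsWritten += 1
      if (objectsWritten == serializerBatchSize) {
        flush()
        writer.close()
        writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
      }
    }
    if (objectsWritten > 0) flush()
]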
@@ -239,12 +224,12 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
private class ExternalIterator extends Iterator[(K, C)] {
// A fixed-size queue that maintains a buffer for each stream we are currently merging
- val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
+ private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
- val sortedMap = currentMap.destructiveSortedIterator(comparator)
- val inputStreams = Seq(sortedMap) ++ spilledMaps
+ private val sortedMap = currentMap.destructiveSortedIterator(comparator)
+ private val inputStreams = Seq(sortedMap) ++ spilledMaps
inputStreams.foreach { it =>
val kcPairs = getMorePairs(it)
@@ -252,11 +237,12 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
}
/**
- * Fetch from the given iterator until a key of different hash is retrieved. In the
- * event of key hash collisions, this ensures no pairs are hidden from being merged.
+ * Fetch from the given iterator until a key of different hash is retrieved.
+ *
+ * In the event of key hash collisions, this ensures no pairs are hidden from being merged.
* Assume the given iterator is in sorted order.
*/
- def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
+ private def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
val kcPairs = new ArrayBuffer[(K, C)]
if (it.hasNext) {
var kc = it.next()
@@ -274,7 +260,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
* If the given buffer contains a value for the given key, merge that value into
* baseCombiner and remove the corresponding (K, C) pair from the buffer
*/
- def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
+ private def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
var i = 0
while (i < buffer.pairs.size) {
val (k, c) = buffer.pairs(i)
@@ -293,7 +279,8 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
/**
- * Select a key with the minimum hash, then combine all values with the same key from all input streams.
+ * Select a key with the minimum hash, then combine all values with the same key from all
+ * input streams
*/
override def next(): (K, C) = {
// Select a key from the StreamBuffer that holds the lowest key hash
@@ -333,7 +320,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
*
* StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
*/
- case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
+ private case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
extends Comparable[StreamBuffer] {
def minKeyHash: Int = {
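
[Reviewer note: one subtlety here is that scala.collection.mutable.PriorityQueue dequeues its largest element first, so a StreamBuffer ordering that surfaces the smallest key hash has to invert the comparison. A sketch of a Comparable implementation consistent with that, not necessarily the file's exact body:

    private case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
      extends Comparable[StreamBuffer] {

      // All buffered pairs share one key hash (see getMorePairs), so any
      // element is representative; callers ensure pairs is non-empty.
      def minKeyHash: Int = pairs.head._1.hashCode()

      // Inverted so the max-heap PriorityQueue yields the minimum hash first.
      override def compareTo(other: StreamBuffer): Int =
        if (other.minKeyHash < minKeyHash) -1
        else if (other.minKeyHash == minKeyHash) 0
        else 1
    }
]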
@@ -355,51 +342,53 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
/**
* An iterator that returns (K, C) pairs in sorted order from an on-disk map
*/
- private class DiskMapIterator(file: File, blockId: BlockId) extends Iterator[(K, C)] {
- val fileStream = new FileInputStream(file)
- val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
-
- val shouldCompress = blockManager.shouldCompress(blockId)
- val compressionCodec = new LZFCompressionCodec(sparkConf)
- val compressedStream =
- if (shouldCompress) {
- compressionCodec.compressedInputStream(bufferedStream)
+ private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
+ extends Iterator[(K, C)] {
+ private val fileStream = new FileInputStream(file)
+ private val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
+
+ // An intermediate stream that reads from exactly one batch
+ // This guards against pre-fetching and other arbitrary behavior of higher level streams
+ private var batchStream = nextBatchStream()
+ private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+ private var deserializeStream = ser.deserializeStream(compressedStream)
+ private var nextItem: (K, C) = null
+ private var objectsRead = 0
+
+ /**
+ * Construct a stream that reads only from the next batch
+ */
+ private def nextBatchStream(): InputStream = {
+ if (batchSizes.length > 0) {
+ ByteStreams.limit(bufferedStream, batchSizes.remove(0))
} else {
+ // No more batches left
bufferedStream
}
- var deserializeStream = ser.deserializeStream(compressedStream)
- var objectsRead = 0
-
- var nextItem: (K, C) = null
- var eof = false
-
- def readNextItem(): (K, C) = {
- if (!eof) {
- try {
- if (objectsRead == serializerBatchSize) {
- val newInputStream = deserializeStream match {
- case stream: KryoDeserializationStream =>
- // Kryo's serializer stores an internal buffer that pre-fetches from the underlying
- // stream. We need to capture this buffer and feed it to the new serialization
- // stream so that the bytes are not lost.
- val kryoInput = stream.input
- val remainingBytes = kryoInput.limit() - kryoInput.position()
- val extraBuf = kryoInput.readBytes(remainingBytes)
- new SequenceInputStream(new ByteArrayInputStream(extraBuf), compressedStream)
- case _ => compressedStream
- }
- deserializeStream = ser.deserializeStream(newInputStream)
- objectsRead = 0
- }
- objectsRead += 1
- return deserializeStream.readObject().asInstanceOf[(K, C)]
- } catch {
- case e: EOFException =>
- eof = true
- cleanup()
+ }
+
+ /**
+ * Return the next (K, C) pair from the deserialization stream.
+ *
+ * If the current batch is drained, construct a stream for the next batch and read from it.
+ * If no more pairs are left, return null.
+ */
+ private def readNextItem(): (K, C) = {
+ try {
+ val item = deserializeStream.readObject().asInstanceOf[(K, C)]
+ objectsRead += 1
+ if (objectsRead == serializerBatchSize) {
+ batchStream = nextBatchStream()
+ compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+ deserializeStream = ser.deserializeStream(compressedStream)
+ objectsRead = 0
}
+ item
+ } catch {
+ case e: EOFException =>
+ cleanup()
+ null
}
- null
}
override def hasNext: Boolean = {
@@ -419,7 +408,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
}
// TODO: Ensure this gets called even if the iterator isn't drained.
- def cleanup() {
+ private def cleanup() {
deserializeStream.close()
file.delete()
}
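
[Reviewer note: the read side depends on Guava's ByteStreams.limit, which wraps an InputStream so it reports end-of-stream after a fixed number of bytes. That is what lets each batch's compression and deserialization streams be rebuilt safely, regardless of codec. A self-contained illustration on toy data, not Spark code:

    import java.io.ByteArrayInputStream
    import com.google.common.io.ByteStreams

    val underlying = new ByteArrayInputStream(Array.fill[Byte](10)(1))
    val batch = ByteStreams.limit(underlying, 4)

    val buf = new Array[Byte](10)
    batch.read(buf)                  // returns 4: only the batch's bytes are visible
    assert(batch.read() == -1)       // the batch boundary looks like EOF to readers
    // `underlying` still holds the remaining 6 bytes for the next batch view.
]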
diff --git a/docs/configuration.md b/docs/configuration.md
index 1f9fa70566..8e4c48c81f 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -158,9 +158,7 @@ Apart from these, the following properties are also available, and may be useful
<td>spark.shuffle.spill.compress</td>
<td>true</td>
<td>
- Whether to compress data spilled during shuffles. If enabled, spill compression
- always uses the `org.apache.spark.io.LZFCompressionCodec` codec,
- regardless of the value of `spark.io.compression.codec`.
+ Whether to compress data spilled during shuffles.
</td>
</tr>
<tr>
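
[Reviewer note: since spills now honor the configured codec, the two settings compose as expected; a hypothetical SparkConf snippet:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.shuffle.spill.compress", "true")
      // Any codec now applies to spills too; before this patch, spills were
      // pinned to LZF regardless of this setting.
      .set("spark.io.compression.codec", "org.apache.spark.io.SnappyCompressionCodec")
]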