Diffstat (limited to 'core')
3 files changed, 43 insertions, 10 deletions
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 8b30cd4bfe..ee4467085f 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -32,7 +32,7 @@ case class Aggregator[K, V, C] (
     mergeCombiners: (C, C) => C) {
 
   private val sparkConf = SparkEnv.get.conf
-  private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
+  private val externalSorting = sparkConf.getBoolean("spark.shuffle.external", true)
 
   def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
     if (!externalSorting) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 6f1345c57a..0e770ed152 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -80,6 +80,8 @@ private[spark] class BlockManager(
   val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
   // Whether to compress RDD partitions that are stored serialized
   val compressRdds = conf.getBoolean("spark.rdd.compress", false)
+  // Whether to compress shuffle output temporarily spilled to disk
+  val compressExternalShuffle = conf.getBoolean("spark.shuffle.external.compress", false)
 
   val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
 
@@ -790,6 +792,7 @@ private[spark] class BlockManager(
     case ShuffleBlockId(_, _, _) => compressShuffle
     case BroadcastBlockId(_) => compressBroadcast
     case RDDBlockId(_, _) => compressRdds
+    case TempBlockId(_) => compressExternalShuffle
     case _ => false
   }
 
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index e3bcd895aa..fd17413952 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
+import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockManager, DiskBlockObjectWriter}
 
 /**
  * An append-only map that spills sorted content to disk when there is insufficient space for it
@@ -60,7 +60,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     mergeValue: (C, V) => C,
     mergeCombiners: (C, C) => C,
     serializer: Serializer = SparkEnv.get.serializerManager.default,
-    diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager)
+    blockManager: BlockManager = SparkEnv.get.blockManager)
   extends Iterable[(K, C)] with Serializable with Logging {
 
   import ExternalAppendOnlyMap._
@@ -68,6 +68,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
   private val spilledMaps = new ArrayBuffer[DiskMapIterator]
   private val sparkConf = SparkEnv.get.conf
+  private val diskBlockManager = blockManager.diskBlockManager
 
   // Collective memory threshold shared across all running tasks
   private val maxMemoryThreshold = {
@@ -82,6 +83,14 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
   private val trackMemoryThreshold = 1000
 
+  // Size of object batches when reading/writing from serializers. Objects are written in
+  // batches, with each batch using its own serialization stream. This cuts down on the size
+  // of reference-tracking maps constructed when deserializing a stream.
+  //
+  // NOTE: Setting this too low can cause excess copying when serializing, since some serializers
+  // grow internal data structures by growing + copying every time the number of objects doubles.
+  private val serializerBatchSize = sparkConf.getLong("spark.shuffle.external.batchSize", 10000)
+
   // How many times we have spilled so far
   private var spillCount = 0
 
@@ -139,21 +148,34 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
       .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
-    val writer =
-      new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites)
+
+    val compressStream: OutputStream => OutputStream = blockManager.wrapForCompression(blockId, _)
+    def getNewWriter = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize,
+      compressStream, syncWrites)
+
+    var writer = getNewWriter
+    var objectsWritten = 0
     try {
       val it = currentMap.destructiveSortedIterator(comparator)
       while (it.hasNext) {
         val kv = it.next()
         writer.write(kv)
+        objectsWritten += 1
+
+        if (objectsWritten == serializerBatchSize) {
+          writer.commit()
+          writer = getNewWriter
+          objectsWritten = 0
+        }
       }
-      writer.commit()
+
+      if (objectsWritten > 0) writer.commit()
     } finally {
       // Partial failures cannot be tolerated; do not revert partial writes
       writer.close()
     }
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
-    spilledMaps.append(new DiskMapIterator(file))
+    spilledMaps.append(new DiskMapIterator(file, blockId))
 
     // Reset the amount of shuffle memory used by this map in the global pool
     val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
@@ -297,16 +319,24 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   /**
    * An iterator that returns (K, C) pairs in sorted order from an on-disk map
    */
-  private class DiskMapIterator(file: File) extends Iterator[(K, C)] {
+  private class DiskMapIterator(file: File, blockId: BlockId) extends Iterator[(K, C)] {
     val fileStream = new FileInputStream(file)
-    val bufferedStream = new FastBufferedInputStream(fileStream)
-    val deserializeStream = ser.deserializeStream(bufferedStream)
+    val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
+    val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+    var deserializeStream = ser.deserializeStream(compressedStream)
+    var objectsRead = 0
+
     var nextItem: (K, C) = null
     var eof = false
 
     def readNextItem(): (K, C) = {
       if (!eof) {
         try {
+          if (objectsRead == serializerBatchSize) {
+            deserializeStream = ser.deserializeStream(compressedStream)
+            objectsRead = 0
+          }
+          objectsRead += 1
           return deserializeStream.readObject().asInstanceOf[(K, C)]
         } catch {
           case e: EOFException =>
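The spill path touched above is driven by three configuration keys: "spark.shuffle.external" (renamed from "spark.shuffle.externalSorting"), the new "spark.shuffle.external.compress", and the new "spark.shuffle.external.batchSize". The following is a minimal usage sketch only; the key names and default values come from this diff, while the local master, app name, and example job are illustrative and not part of the change.

    import org.apache.spark.{SparkConf, SparkContext}

    // Sketch: set the knobs introduced/renamed in this diff on a SparkConf.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("external-spill-config-sketch")
      .set("spark.shuffle.external", "true")            // renamed from spark.shuffle.externalSorting
      .set("spark.shuffle.external.compress", "false")  // new: compress TempBlockId spill files on disk
      .set("spark.shuffle.external.batchSize", "10000") // new: objects written per serialization stream

    val sc = new SparkContext(conf)
    // reduceByKey combines through Aggregator, which spills via ExternalAppendOnlyMap when enabled.
    sc.parallelize(1 to 1000000).map(i => (i % 1000, 1)).reduceByKey(_ + _).count()
    sc.stop()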