Diffstat (limited to 'core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala  48
1 file changed, 13 insertions(+), 35 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 1f7d2dc838..cc0423856c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C](
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
private val diskBlockManager = blockManager.diskBlockManager
-
- // Collective memory threshold shared across all running tasks
- private val maxMemoryThreshold = {
- val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2)
- val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8)
- (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
- }
+ private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

// Number of pairs inserted since last spill; note that we count them even if a value is merged
// with a previous key in case we're doing something like groupBy where the result grows
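
The removed block sized a single global threshold from spark.shuffle.memoryFraction (default 0.2) and spark.shuffle.safetyFraction (default 0.8) of the JVM's max heap; the replacement delegates that bookkeeping to SparkEnv's ShuffleMemoryManager. As a rough illustration of the contract the new code relies on (tryToAcquire returns the number of bytes actually granted, release hands them back), here is a deliberately simplified stand-in, not Spark's actual class; the real manager also enforces per-task fairness, which this sketch omits:

// Illustrative sketch only: a minimal pool with the same contract the new code
// assumes. Spark's real ShuffleMemoryManager additionally arbitrates between
// running tasks; this stand-in just grants whatever is left in the pool.
class SimpleShuffleMemoryPool(maxMemory: Long) {
  private var used = 0L

  // Grant up to numBytes; returns the amount actually granted (possibly 0).
  def tryToAcquire(numBytes: Long): Long = synchronized {
    val granted = math.min(numBytes, maxMemory - used)
    used += granted
    granted
  }

  // Hand previously granted bytes back to the pool.
  def release(numBytes: Long): Unit = synchronized {
    used = math.max(0L, used - numBytes)
  }
}
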
@@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C](
if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
currentMap.estimateSize() >= myMemoryThreshold)
{
- val currentSize = currentMap.estimateSize()
- var shouldSpill = false
- val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
-
- // Atomically check whether there is sufficient memory in the global pool for
- // this map to grow and, if possible, allocate the required amount
- shuffleMemoryMap.synchronized {
- val threadId = Thread.currentThread().getId
- val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
- val availableMemory = maxMemoryThreshold -
- (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
-
- // Try to allocate at least 2x more memory, otherwise spill
- shouldSpill = availableMemory < currentSize * 2
- if (!shouldSpill) {
- shuffleMemoryMap(threadId) = currentSize * 2
- myMemoryThreshold = currentSize * 2
- }
- }
- // Do not synchronize spills
- if (shouldSpill) {
- spill(currentSize)
+ // Claim up to double our current memory from the shuffle memory pool
+ val currentMemory = currentMap.estimateSize()
+ val amountToRequest = 2 * currentMemory - myMemoryThreshold
+ val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+ myMemoryThreshold += granted
+ if (myMemoryThreshold <= currentMemory) {
+ // We were granted too little memory to grow further (either tryToAcquire returned 0,
+ // or we already had more memory than myMemoryThreshold); spill the current collection
+ spill(currentMemory) // Will also release memory back to ShuffleMemoryManager
}
}
currentMap.changeValue(curEntry._1, update)
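
To make the new arithmetic concrete, here is a worked example with hypothetical sizes: the map asks the pool for enough bytes to double its total claim, and spills only when the grant leaves its threshold at or below the map's current estimated size.

// Worked example of the acquire-or-spill arithmetic above, with made-up numbers.
val currentMemory = 100L << 20                               // map estimated at 100 MB
var myMemoryThreshold = 60L << 20                            // 60 MB already claimed
val amountToRequest = 2 * currentMemory - myMemoryThreshold  // 140 MB, i.e. a 200 MB total claim
val granted = 30L << 20                                      // suppose the pool can spare only 30 MB
myMemoryThreshold += granted                                 // 90 MB, still <= currentMemory
val shouldSpill = myMemoryThreshold <= currentMemory         // true: spill to disk
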
@@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C](
currentMap = new SizeTrackingAppendOnlyMap[K, C]
spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))

- // Reset the amount of shuffle memory used by this map in the global pool
- val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
- shuffleMemoryMap.synchronized {
- shuffleMemoryMap(Thread.currentThread().getId) = 0
- }
- myMemoryThreshold = 0
+ // Release our memory back to the shuffle pool so that other threads can grab it
+ shuffleMemoryManager.release(myMemoryThreshold)
+ myMemoryThreshold = 0L

elementsRead = 0
_memoryBytesSpilled += mapSize
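
Putting the two hunks together: tryToAcquire grows this map's claim while records are inserted, and spill() returns the whole claim to the pool so other tasks can use it. A hedged end-to-end sketch of that cycle, using the simplified pool above and a hypothetical maybeSpill helper (not Spark code):

// Hypothetical helper tying the two halves of the diff together.
val pool = new SimpleShuffleMemoryPool(maxMemory = 512L << 20)
var myMemoryThreshold = 0L

def maybeSpill(estimatedSize: Long): Boolean = {
  myMemoryThreshold += pool.tryToAcquire(2 * estimatedSize - myMemoryThreshold)
  if (myMemoryThreshold <= estimatedSize) {
    pool.release(myMemoryThreshold)  // release everything, as the new spill() path does
    myMemoryThreshold = 0L
    true                             // caller should spill the in-memory map to disk
  } else {
    false
  }
}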