diff options
Diffstat (limited to 'core/src/main')
-rw-r--r-- | core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java | 56 |
1 files changed, 43 insertions, 13 deletions
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index c40974b54c..39fb3b249d 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -20,8 +20,12 @@ package org.apache.spark.memory; import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.Arrays; +import java.util.ArrayList; import java.util.BitSet; import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -144,23 +148,49 @@ public class TaskMemoryManager { // spilling, avoid to have too many spilled files. if (got < required) { // Call spill() on other consumers to release memory + // Sort the consumers according their memory usage. So we avoid spilling the same consumer + // which is just spilled in last few times and re-spilling on it will produce many small + // spill files. + TreeMap<Long, List<MemoryConsumer>> sortedConsumers = new TreeMap<>(); for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { - try { - long released = c.spill(required - got, consumer); - if (released > 0) { - logger.debug("Task {} released {} from {} for {}", taskAttemptId, - Utils.bytesToString(released), c, consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); - if (got >= required) { - break; - } + long key = c.getUsed(); + List<MemoryConsumer> list = sortedConsumers.get(key); + if (list == null) { + list = new ArrayList<>(1); + sortedConsumers.put(key, list); + } + list.add(c); + } + } + while (!sortedConsumers.isEmpty()) { + // Get the consumer using the least memory more than the remaining required memory. + Map.Entry<Long, List<MemoryConsumer>> currentEntry = + sortedConsumers.ceilingEntry(required - got); + // No consumer has used memory more than the remaining required memory. + // Get the consumer of largest used memory. + if (currentEntry == null) { + currentEntry = sortedConsumers.lastEntry(); + } + List<MemoryConsumer> cList = currentEntry.getValue(); + MemoryConsumer c = cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } + try { + long released = c.spill(required - got, consumer); + if (released > 0) { + logger.debug("Task {} released {} from {} for {}", taskAttemptId, + Utils.bytesToString(released), c, consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); + if (got >= required) { + break; } - } catch (IOException e) { - logger.error("error while calling spill() on " + c, e); - throw new OutOfMemoryError("error while calling spill() on " + c + " : " - + e.getMessage()); } + } catch (IOException e) { + logger.error("error while calling spill() on " + c, e); + throw new OutOfMemoryError("error while calling spill() on " + c + " : " + + e.getMessage()); } } } |