aboutsummaryrefslogtreecommitdiff
path: root/core/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main')
-rw-r--r-- core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java | 56
1 files changed, 43 insertions, 13 deletions
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index c40974b54c..39fb3b249d 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -20,8 +20,12 @@ package org.apache.spark.memory;
import javax.annotation.concurrent.GuardedBy;
import java.io.IOException;
import java.util.Arrays;
+import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
@@ -144,23 +148,49 @@ public class TaskMemoryManager {
// spilling, avoid to have too many spilled files.
if (got < required) {
// Call spill() on other consumers to release memory
+ // Sort the consumers according to their memory usage, so that we avoid spilling a consumer
+ // that was just spilled in the last few rounds; re-spilling it would produce many small
+ // spill files.
+ TreeMap<Long, List<MemoryConsumer>> sortedConsumers = new TreeMap<>();
for (MemoryConsumer c: consumers) {
if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) {
- try {
- long released = c.spill(required - got, consumer);
- if (released > 0) {
- logger.debug("Task {} released {} from {} for {}", taskAttemptId,
- Utils.bytesToString(released), c, consumer);
- got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
- if (got >= required) {
- break;
- }
+ long key = c.getUsed();
+ List<MemoryConsumer> list = sortedConsumers.get(key);
+ if (list == null) {
+ list = new ArrayList<>(1);
+ sortedConsumers.put(key, list);
+ }
+ list.add(c);
+ }
+ }
+ while (!sortedConsumers.isEmpty()) {
+ // Get the consumer using the least memory that is at least the remaining required memory.
+ Map.Entry<Long, List<MemoryConsumer>> currentEntry =
+ sortedConsumers.ceilingEntry(required - got);
+ // No consumer has used memory more than the remaining required memory.
+ // Get the consumer of largest used memory.
+ if (currentEntry == null) {
+ currentEntry = sortedConsumers.lastEntry();
+ }
+ List<MemoryConsumer> cList = currentEntry.getValue();
+ MemoryConsumer c = cList.remove(cList.size() - 1);
+ if (cList.isEmpty()) {
+ sortedConsumers.remove(currentEntry.getKey());
+ }
+ try {
+ long released = c.spill(required - got, consumer);
+ if (released > 0) {
+ logger.debug("Task {} released {} from {} for {}", taskAttemptId,
+ Utils.bytesToString(released), c, consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
+ if (got >= required) {
+ break;
}
- } catch (IOException e) {
- logger.error("error while calling spill() on " + c, e);
- throw new OutOfMemoryError("error while calling spill() on " + c + " : "
- + e.getMessage());
}
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + c, e);
+ throw new OutOfMemoryError("error while calling spill() on " + c + " : "
+ + e.getMessage());
}
}
}