diff options
-rw-r--r-- | core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java | 56
-rw-r--r-- | core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java | 35
2 files changed, 78 insertions, 13 deletions
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index c40974b54c..39fb3b249d 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -20,8 +20,12 @@ package org.apache.spark.memory; import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.Arrays; +import java.util.ArrayList; import java.util.BitSet; import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -144,23 +148,49 @@ public class TaskMemoryManager { // spilling, avoid to have too many spilled files. if (got < required) { // Call spill() on other consumers to release memory + // Sort the consumers according to their memory usage, so we avoid spilling the same consumer + // that was just spilled in the last few attempts; re-spilling it would produce many small + // spill files. + TreeMap<Long, List<MemoryConsumer>> sortedConsumers = new TreeMap<>(); for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { - try { - long released = c.spill(required - got, consumer); - if (released > 0) { - logger.debug("Task {} released {} from {} for {}", taskAttemptId, - Utils.bytesToString(released), c, consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); - if (got >= required) { - break; - } + long key = c.getUsed(); + List<MemoryConsumer> list = sortedConsumers.get(key); + if (list == null) { + list = new ArrayList<>(1); + sortedConsumers.put(key, list); + } + list.add(c); + } + } + while (!sortedConsumers.isEmpty()) { + // Get the consumer that uses the least memory not smaller than the remaining required memory.
+ Map.Entry<Long, List<MemoryConsumer>> currentEntry = + sortedConsumers.ceilingEntry(required - got); + // No consumer has used memory more than the remaining required memory. + // Get the consumer of largest used memory. + if (currentEntry == null) { + currentEntry = sortedConsumers.lastEntry(); + } + List<MemoryConsumer> cList = currentEntry.getValue(); + MemoryConsumer c = cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } + try { + long released = c.spill(required - got, consumer); + if (released > 0) { + logger.debug("Task {} released {} from {} for {}", taskAttemptId, + Utils.bytesToString(released), c, consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); + if (got >= required) { + break; } - } catch (IOException e) { - logger.error("error while calling spill() on " + c, e); - throw new OutOfMemoryError("error while calling spill() on " + c + " : " - + e.getMessage()); } + } catch (IOException e) { + logger.error("error while calling spill() on " + c, e); + throw new OutOfMemoryError("error while calling spill() on " + c + " : " + + e.getMessage()); } } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index ad755529de..f53bc0b02b 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -110,6 +110,41 @@ public class TaskMemoryManagerSuite { } @Test + public void cooperativeSpilling2() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(100); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0); + + TestMemoryConsumer c1 = new TestMemoryConsumer(manager); + TestMemoryConsumer c2 = new TestMemoryConsumer(manager); + TestMemoryConsumer c3 = new TestMemoryConsumer(manager); 
+ + c1.use(20); + Assert.assertEquals(20, c1.getUsed()); + c2.use(80); + Assert.assertEquals(80, c2.getUsed()); + c3.use(80); + Assert.assertEquals(20, c1.getUsed()); // c1: not spilled + Assert.assertEquals(0, c2.getUsed()); // c2: spilled as it has required size of memory + Assert.assertEquals(80, c3.getUsed()); + + c2.use(80); + Assert.assertEquals(20, c1.getUsed()); // c1: not spilled + Assert.assertEquals(0, c3.getUsed()); // c3: spilled as it has required size of memory + Assert.assertEquals(80, c2.getUsed()); + + c3.use(10); + Assert.assertEquals(0, c1.getUsed()); // c1: spilled as it has required size of memory + Assert.assertEquals(80, c2.getUsed()); // c2: not spilled as spilling c1 already satisfies c3 + Assert.assertEquals(10, c3.getUsed()); + + c1.free(0); + c2.free(80); + c3.free(10); + Assert.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + } + + @Test public void shouldNotForceSpillingInDifferentModes() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); memoryManager.limit(100); |