aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java56
-rw-r--r--core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java35
2 files changed, 78 insertions, 13 deletions
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index c40974b54c..39fb3b249d 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -20,8 +20,12 @@ package org.apache.spark.memory;
import javax.annotation.concurrent.GuardedBy;
import java.io.IOException;
import java.util.Arrays;
+import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
@@ -144,23 +148,49 @@ public class TaskMemoryManager {
// spilling, avoid to have too many spilled files.
if (got < required) {
// Call spill() on other consumers to release memory
+ // Sort the consumers according to their memory usage, so that we avoid spilling the same
+ // consumer that was just spilled in the last few attempts; re-spilling it would produce
+ // many small spill files.
+ TreeMap<Long, List<MemoryConsumer>> sortedConsumers = new TreeMap<>();
for (MemoryConsumer c: consumers) {
if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) {
- try {
- long released = c.spill(required - got, consumer);
- if (released > 0) {
- logger.debug("Task {} released {} from {} for {}", taskAttemptId,
- Utils.bytesToString(released), c, consumer);
- got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
- if (got >= required) {
- break;
- }
+ long key = c.getUsed();
+ List<MemoryConsumer> list = sortedConsumers.get(key);
+ if (list == null) {
+ list = new ArrayList<>(1);
+ sortedConsumers.put(key, list);
+ }
+ list.add(c);
+ }
+ }
+ while (!sortedConsumers.isEmpty()) {
+ // Get the consumer using the least memory that is still at least the remaining required memory.
+ Map.Entry<Long, List<MemoryConsumer>> currentEntry =
+ sortedConsumers.ceilingEntry(required - got);
+ // If no consumer has used at least the remaining required amount of memory,
+ // fall back to the consumer with the largest used memory.
+ if (currentEntry == null) {
+ currentEntry = sortedConsumers.lastEntry();
+ }
+ List<MemoryConsumer> cList = currentEntry.getValue();
+ MemoryConsumer c = cList.remove(cList.size() - 1);
+ if (cList.isEmpty()) {
+ sortedConsumers.remove(currentEntry.getKey());
+ }
+ try {
+ long released = c.spill(required - got, consumer);
+ if (released > 0) {
+ logger.debug("Task {} released {} from {} for {}", taskAttemptId,
+ Utils.bytesToString(released), c, consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
+ if (got >= required) {
+ break;
}
- } catch (IOException e) {
- logger.error("error while calling spill() on " + c, e);
- throw new OutOfMemoryError("error while calling spill() on " + c + " : "
- + e.getMessage());
}
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + c, e);
+ throw new OutOfMemoryError("error while calling spill() on " + c + " : "
+ + e.getMessage());
}
}
}
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index ad755529de..f53bc0b02b 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -110,6 +110,41 @@ public class TaskMemoryManagerSuite {
}
@Test
+ public void cooperativeSpilling2() {
+ final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
+ memoryManager.limit(100);
+ final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0);
+
+ TestMemoryConsumer c1 = new TestMemoryConsumer(manager);
+ TestMemoryConsumer c2 = new TestMemoryConsumer(manager);
+ TestMemoryConsumer c3 = new TestMemoryConsumer(manager);
+
+ c1.use(20);
+ Assert.assertEquals(20, c1.getUsed());
+ c2.use(80);
+ Assert.assertEquals(80, c2.getUsed());
+ c3.use(80);
+ Assert.assertEquals(20, c1.getUsed()); // c1: not spilled
+ Assert.assertEquals(0, c2.getUsed()); // c2: spilled as it holds the required amount of memory
+ Assert.assertEquals(80, c3.getUsed());
+
+ c2.use(80);
+ Assert.assertEquals(20, c1.getUsed()); // c1: not spilled
+ Assert.assertEquals(0, c3.getUsed()); // c3: spilled as it holds the required amount of memory
+ Assert.assertEquals(80, c2.getUsed());
+
+ c3.use(10);
+ Assert.assertEquals(0, c1.getUsed()); // c1: spilled as it holds the required amount of memory
+ Assert.assertEquals(80, c2.getUsed()); // c2: not spilled as spilling c1 already satisfies c3
+ Assert.assertEquals(10, c3.getUsed());
+
+ c1.free(0);
+ c2.free(80);
+ c3.free(10);
+ Assert.assertEquals(0, manager.cleanUpAllAllocatedMemory());
+ }
+
+ @Test
public void shouldNotForceSpillingInDifferentModes() {
final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
memoryManager.limit(100);