aboutsummaryrefslogtreecommitdiff
path: root/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java')
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java53
1 files changed, 24 insertions, 29 deletions
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 29d9823b1f..d65926949c 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -39,7 +39,6 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
@@ -54,19 +53,15 @@ import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.serializer.*;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
public class UnsafeShuffleWriterSuite {
static final int NUM_PARTITITONS = 4;
- final TaskMemoryManager taskMemoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ TaskMemoryManager taskMemoryManager;
final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
File mergedOutputFile;
File tempDir;
@@ -76,7 +71,6 @@ public class UnsafeShuffleWriterSuite {
final Serializer serializer = new KryoSerializer(new SparkConf());
TaskMetrics taskMetrics;
- @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@@ -111,11 +105,11 @@ public class UnsafeShuffleWriterSuite {
mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
partitionSizesInMergedFile = null;
spillFilesCreated.clear();
- conf = new SparkConf().set("spark.buffer.pageSize", "128m");
+ conf = new SparkConf()
+ .set("spark.buffer.pageSize", "128m")
+ .set("spark.unsafe.offHeap", "false");
taskMetrics = new TaskMetrics();
-
- when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
- when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
+ taskMemoryManager = new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(blockManager.getDiskWriter(
@@ -203,7 +197,6 @@ public class UnsafeShuffleWriterSuite {
blockManager,
shuffleBlockResolver,
taskMemoryManager,
- shuffleMemoryManager,
new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
0, // map id
taskContext,
@@ -405,11 +398,12 @@ public class UnsafeShuffleWriterSuite {
@Test
public void writeEnoughDataToTriggerSpill() throws Exception {
- when(shuffleMemoryManager.tryToAcquire(anyLong()))
- .then(returnsFirstArg()) // Allocate initial sort buffer
- .then(returnsFirstArg()) // Allocate initial data page
- .thenReturn(0L) // Deny request to allocate new data page
- .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ taskMemoryManager = spy(taskMemoryManager);
+ doCallRealMethod() // initialize sort buffer
+ .doCallRealMethod() // allocate initial data page
+ .doReturn(0L) // deny request to allocate new page
+ .doCallRealMethod() // grant new sort buffer and data page
+ .when(taskMemoryManager).acquireExecutionMemory(anyLong());
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
@@ -417,7 +411,7 @@ public class UnsafeShuffleWriterSuite {
dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
}
writer.write(dataToWrite.iterator());
- verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
@@ -432,18 +426,19 @@ public class UnsafeShuffleWriterSuite {
@Test
public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
- when(shuffleMemoryManager.tryToAcquire(anyLong()))
- .then(returnsFirstArg()) // Allocate initial sort buffer
- .then(returnsFirstArg()) // Allocate initial data page
- .thenReturn(0L) // Deny request to grow sort buffer
- .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ taskMemoryManager = spy(taskMemoryManager);
+ doCallRealMethod() // initialize sort buffer
+ .doCallRealMethod() // allocate initial data page
+ .doReturn(0L) // deny request to allocate new page
+ .doCallRealMethod() // grant new sort buffer and data page
+ .when(taskMemoryManager).acquireExecutionMemory(anyLong());
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
- final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
dataToWrite.add(new Tuple2<Object, Object>(i, i));
}
writer.write(dataToWrite.iterator());
- verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
@@ -509,13 +504,13 @@ public class UnsafeShuffleWriterSuite {
final long recordLengthBytes = 8;
final long pageSizeBytes = 256;
final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
- when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
+ taskMemoryManager = spy(taskMemoryManager);
+ when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
final UnsafeShuffleWriter<Object, Object> writer =
new UnsafeShuffleWriter<Object, Object>(
blockManager,
shuffleBlockResolver,
taskMemoryManager,
- shuffleMemoryManager,
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,