diff options
Diffstat (limited to 'core/src/test/java')
3 files changed, 20 insertions, 41 deletions
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 47c695ad4e..44733dcdaf 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -70,6 +70,7 @@ public class UnsafeShuffleWriterSuite { final LinkedList<File> spillFilesCreated = new LinkedList<>(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); + final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf()); TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -111,7 +112,7 @@ public class UnsafeShuffleWriterSuite { .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); - taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -135,35 +136,6 @@ public class UnsafeShuffleWriterSuite { ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( - new Answer<InputStream>() { - @Override - public InputStream answer(InvocationOnMock invocation) throws Throwable { - assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId); - InputStream is = (InputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); - } else { - return is; - } - } - } - ); - - when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( - new Answer<OutputStream>() { - @Override - public OutputStream answer(InvocationOnMock invocation) throws Throwable { - assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId); - OutputStream os = (OutputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); - } else { - return os; - } - } - } - ); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(new Answer<Void>() { diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 6667179b9d..449fb45c30 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -19,7 +19,6 @@ package org.apache.spark.unsafe.map; import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; @@ -42,7 +41,9 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -51,7 +52,6 @@ import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -64,6 +64,9 @@ public abstract class AbstractBytesToBytesMapSuite { private TestMemoryManager memoryManager; private TaskMemoryManager taskMemoryManager; + private SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes final LinkedList<File> spillFilesCreated = new LinkedList<>(); @@ -85,7 +88,9 @@ public abstract class AbstractBytesToBytesMapSuite { new TestMemoryManager( new SparkConf() .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) - .set("spark.memory.offHeap.size", "256mb")); + .set("spark.memory.offHeap.size", "256mb") + .set("spark.shuffle.spill.compress", "false") + .set("spark.shuffle.compress", "false")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); @@ -124,8 +129,6 @@ public abstract class AbstractBytesToBytesMapSuite { ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -546,8 +549,8 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void spillInIterator() throws IOException { - BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false); + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false); try { int i; for (i = 0; i < 1024; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index db50e551f2..a2253d8559 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,7 +19,6 @@ package org.apache.spark.util.collection.unsafe.sort; import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.util.Arrays; import java.util.LinkedList; @@ -43,14 +42,15 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -60,6 +60,9 @@ public class UnsafeExternalSorterSuite { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -135,8 +138,6 @@ public class UnsafeExternalSorterSuite { ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -172,6 +173,7 @@ public class UnsafeExternalSorterSuite { return UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -374,6 +376,7 @@ public class UnsafeExternalSorterSuite { final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, null, null, @@ -408,6 +411,7 @@ public class UnsafeExternalSorterSuite { final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, |