aboutsummaryrefslogtreecommitdiff
path: root/core/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/test')
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java202
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java139
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala50
3 files changed, 391 insertions, 0 deletions
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
new file mode 100644
index 0000000000..ea8755e21e
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.UUID;
+
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeExternalSorterSuite {
+
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+
+ File tempDir;
+
+ private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ return stream;
+ }
+ }
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ tempDir = new File(Utils.createTempDir$default$1());
+ taskContext = mock(TaskContext.class);
+ when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+ @Override
+ public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+ TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (BlockId) args[0],
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+ .then(returnsSecondArg());
+ }
+
+ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
+ final int[] arr = new int[] { value };
+ sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
+ }
+
+ @Test
+ public void testSortingOnlyByPrefix() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ insertNumber(sorter, 5);
+ insertNumber(sorter, 1);
+ insertNumber(sorter, 3);
+ sorter.spill();
+ insertNumber(sorter, 4);
+ sorter.spill();
+ insertNumber(sorter, 2);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(i, iter.getKeyPrefix());
+ assertEquals(4, iter.getRecordLength());
+ // TODO: read rest of value.
+ }
+
+ // TODO: test for cleanup:
+ // assert(tempDir.isEmpty)
+ }
+
+ @Test
+ public void testSortingEmptyArrays() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(0, iter.getKeyPrefix());
+ assertEquals(0, iter.getRecordLength());
+ }
+ }
+
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
new file mode 100644
index 0000000000..9095009305
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Arrays;
+
+import org.junit.Test;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.*;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
+ final byte[] strBytes = new byte[length];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, length);
+ return new String(strBytes);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+ mock(RecordComparator.class),
+ mock(PrefixComparator.class),
+ 100);
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ assert(!iter.hasNext());
+ }
+
+ @Test
+ public void testSortingOnlyByIntegerPrefix() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ // Write the records into the data page:
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+ final byte[] strBytes = str.getBytes("utf-8");
+ PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ PlatformDependent.copyMemory(
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ position,
+ strBytes.length);
+ position += strBytes.length;
+ }
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+ // Compute key prefixes based on the records' partition ids
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
+ prefixComparator, dataToSort.length);
+ // Given a page of records, insert those records into the sorter one-by-one:
+ position = dataPage.getBaseOffset();
+ for (int i = 0; i < dataToSort.length; i++) {
+ // position now points to the start of a record (which holds its length).
+ final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position);
+ final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
+ final int partitionId = hashPartitioner.getPartition(str);
+ sorter.insertRecord(address, partitionId);
+ position += 4 + recordLength;
+ }
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ int iterLength = 0;
+ long prevPrefix = -1;
+ Arrays.sort(dataToSort);
+ while (iter.hasNext()) {
+ iter.loadNext();
+ final String str =
+ getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength());
+ final long keyPrefix = iter.getKeyPrefix();
+ assertThat(str, isIn(Arrays.asList(dataToSort)));
+ assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
+ prevPrefix = keyPrefix;
+ iterLength++;
+ }
+ assertEquals(dataToSort.length, iterLength);
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
new file mode 100644
index 0000000000..dd505dfa7d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort
+
+import org.scalatest.prop.PropertyChecks
+
+import org.apache.spark.SparkFunSuite
+
+class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
+
+ test("String prefix comparator") {
+
+ def testPrefixComparison(s1: String, s2: String): Unit = {
+ val s1Prefix = PrefixComparators.STRING.computePrefix(s1)
+ val s2Prefix = PrefixComparators.STRING.computePrefix(s2)
+ val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
+ assert(
+ (prefixComparisonResult == 0) ||
+ (prefixComparisonResult < 0 && s1 < s2) ||
+ (prefixComparisonResult > 0 && s1 > s2))
+ }
+
+ // scalastyle:off
+ val regressionTests = Table(
+ ("s1", "s2"),
+ ("abc", "世界"),
+ ("你好", "世界"),
+ ("你好123", "你好122")
+ )
+ // scalastyle:on
+
+ forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ }
+}