 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java        |  36
 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java     |  12
 sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java                |   6
 sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala | 108
 4 files changed, 141 insertions(+), 21 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e6ddd08e5f..8f78fc5a41 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -191,24 +191,29 @@ public final class UnsafeExternalSorter {
spillWriters.size(),
spillWriters.size() > 1 ? " times" : " time");
- final UnsafeSorterSpillWriter spillWriter =
- new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
- inMemSorter.numRecords());
- spillWriters.add(spillWriter);
- final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
- while (sortedRecords.hasNext()) {
- sortedRecords.loadNext();
- final Object baseObject = sortedRecords.getBaseObject();
- final long baseOffset = sortedRecords.getBaseOffset();
- final int recordLength = sortedRecords.getRecordLength();
- spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+      // We only write out the contents of the inMemSorter if it is not empty.
+ if (inMemSorter.numRecords() > 0) {
+ final UnsafeSorterSpillWriter spillWriter =
+ new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+ inMemSorter.numRecords());
+ spillWriters.add(spillWriter);
+ final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final Object baseObject = sortedRecords.getBaseObject();
+ final long baseOffset = sortedRecords.getBaseOffset();
+ final int recordLength = sortedRecords.getRecordLength();
+ spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+ }
+ spillWriter.close();
}
- spillWriter.close();
+
final long spillSize = freeMemory();
// Note that this is more-or-less going to be a multiple of the page size, so wasted space in
// pages will currently be counted as memory spilled even though that space isn't actually
// written to disk. This also counts the space needed to store the sorter's pointer array.
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
initializeForWriting();
}
@@ -505,12 +510,11 @@ public final class UnsafeExternalSorter {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
- spillMerger.addSpill(spillWriter.getReader(blockManager));
+ spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
}
spillWriters.clear();
- if (inMemoryIterator.hasNext()) {
- spillMerger.addSpill(inMemoryIterator);
- }
+ spillMerger.addSpillIfNotEmpty(inMemoryIterator);
+
return spillMerger.getSortedIterator();
}
}
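
The guard above avoids creating a spill file at all when the in-memory sorter holds no records. Below is a minimal, self-contained sketch of the same pattern; the class name, record representation, and file layout are hypothetical stand-ins, not Spark's actual spill format:

    import java.io.DataOutputStream;
    import java.io.File;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.util.List;

    final class SpillGuardSketch {
      // Returns null instead of creating a zero-record spill file.
      static File spillIfNonEmpty(List<byte[]> sortedRecords, File dir) throws IOException {
        if (sortedRecords.isEmpty()) {
          // Nothing to spill; an empty file would later surface as an empty reader
          // in the merge step, which is exactly what this patch is defending against.
          return null;
        }
        File spillFile = File.createTempFile("spill", ".bin", dir);
        try (DataOutputStream out = new DataOutputStream(new FileOutputStream(spillFile))) {
          out.writeInt(sortedRecords.size());
          for (byte[] record : sortedRecords) {
            out.writeInt(record.length);
            out.write(record);
          }
        }
        return spillFile;
      }
    }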
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 8272c2a5be..3874a9f9cb 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -47,11 +47,19 @@ final class UnsafeSorterSpillMerger {
priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
}
- public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+  /**
+   * Adds an UnsafeSorterIterator to this merger, skipping it if it is empty.
+   */
+ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException {
if (spillReader.hasNext()) {
+      // We only add the spillReader to the priorityQueue if it is not empty. We do this to
+      // make sure the hasNext method of the UnsafeSorterIterator returned by getSortedIterator
+      // does not return a wrong result: hasNext returns true at least
+      // priorityQueue.size() times, so if we allowed n empty spillReaders into the
+      // priorityQueue, the merged iterator would yield n extra empty records.
spillReader.loadNext();
+ priorityQueue.add(spillReader);
}
- priorityQueue.add(spillReader);
}
public UnsafeSorterIterator getSortedIterator() throws IOException {
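
The comment above is the heart of the fix: every iterator sitting in the merger's priority queue must already have a record loaded, so an empty iterator in the queue becomes one phantom record in the merged output. Here is a self-contained sketch of that k-way merge invariant, using a plain Iterator<Integer> in place of UnsafeSorterIterator:

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;
    import java.util.PriorityQueue;

    final class KWayMergeSketch {
      // Pairs an iterator with its most recently loaded value.
      private static final class Head {
        int current;
        final Iterator<Integer> rest;
        Head(int current, Iterator<Integer> rest) { this.current = current; this.rest = rest; }
      }

      static List<Integer> merge(List<Iterator<Integer>> inputs) {
        PriorityQueue<Head> queue =
            new PriorityQueue<>((a, b) -> Integer.compare(a.current, b.current));
        for (Iterator<Integer> it : inputs) {
          if (it.hasNext()) {          // the guard: empty inputs never enter the queue
            queue.add(new Head(it.next(), it));
          }
        }
        List<Integer> out = new ArrayList<>();
        while (!queue.isEmpty()) {
          Head head = queue.poll();
          out.add(head.current);       // invariant: 'current' is always a real record
          if (head.rest.hasNext()) {
            head.current = head.rest.next();
            queue.add(head);
          }
        }
        return out;
      }
    }

If the guard were dropped and an empty iterator queued anyway, the first poll of that entry would emit whatever garbage 'current' held, which is the n-extra-records failure the comment describes.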
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 86a563df99..6c1cf136d9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -82,8 +82,11 @@ public final class UnsafeKVExternalSorter {
pageSizeBytes);
} else {
// Insert the records into the in-memory sorter.
+      // We use the number of elements in the map as the initialSize of the
+      // UnsafeInMemorySorter. Since UnsafeInMemorySorter does not accept 0 as an
+      // initial size, we use 1 if the map is empty.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
- taskMemoryManager, recordComparator, prefixComparator, map.numElements());
+ taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
final int numKeyFields = keySchema.size();
BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
@@ -214,7 +217,6 @@ public final class UnsafeKVExternalSorter {
// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;
-
key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
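
The Math.max(1, ...) clamp exists only because UnsafeInMemorySorter rejects a non-positive initial size. A tiny illustration of the pattern, with FixedCapacityBuffer as a hypothetical stand-in for a structure with that precondition:

    final class InitialSizeSketch {
      static final class FixedCapacityBuffer {
        private final long[] slots;
        FixedCapacityBuffer(int initialSize) {
          if (initialSize <= 0) {
            throw new IllegalArgumentException("initialSize must be positive: " + initialSize);
          }
          slots = new long[initialSize];
        }
        int capacity() { return slots.length; }
      }

      static FixedCapacityBuffer forMapOfSize(int numElements) {
        // numElements may legitimately be 0 (an empty map); clamp so construction never fails.
        return new FixedCapacityBuffer(Math.max(1, numElements));
      }
    }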
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index ef827b0fe9..b513c970cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,7 +23,7 @@ import scala.util.{Try, Random}
import org.scalatest.Matchers
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.TestSQLContext
@@ -231,4 +231,110 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
map.free()
}
+
+ testWithMemoryLeakDetection("test external sorting with an empty map") {
+    // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 128, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+
+ // Convert the map into a sorter
+ val sorter = map.destructAndCreateExternalSorter()
+
+ // Add more keys to the sorter and make sure the results come out sorted.
+ val additionalKeys = randomStrings(1024)
+ val keyConverter = UnsafeProjection.create(groupKeySchema)
+ val valueConverter = UnsafeProjection.create(aggBufferSchema)
+
+ additionalKeys.zipWithIndex.foreach { case (str, i) =>
+ val k = InternalRow(UTF8String.fromString(str))
+ val v = InternalRow(str.length)
+ sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
+ if ((i % 100) == 0) {
+ shuffleMemoryManager.markAsOutOfMemory()
+ sorter.closeCurrentPage()
+ }
+ }
+
+ val out = new scala.collection.mutable.ArrayBuffer[String]
+ val iter = sorter.sortedIterator()
+ while (iter.next()) {
+      // Here, we also check that copying the key and value works.
+ val key = iter.getKey.copy()
+ val value = iter.getValue.copy()
+ assert(key.getString(0).length === value.getInt(0))
+ out += key.getString(0)
+ }
+
+    assert(out === additionalKeys.sorted)
+
+ map.free()
+ }
+
+ testWithMemoryLeakDetection("test external sorting with empty records") {
+    // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
+    // Memory consumption at the beginning of the task.
+ val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ StructType(Nil),
+ StructType(Nil),
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 128, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+
+    (1 to 10).foreach { _ =>
+      val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0))
+      assert(buf != null)
+    }
+
+    // Convert the map into a sorter. All ten lookups above used the same (empty) key,
+    // so the map contains exactly one record.
+ val sorter = map.destructAndCreateExternalSorter()
+
+    withClue("destructAndCreateExternalSorter should release memory used by the map") {
+      // 4096 * 16 bytes is the initial allocation for the in-memory sorter's pointer/prefix array.
+      assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() ===
+        initialMemoryConsumption + 4096 * 16)
+    }
+
+ // Add more keys to the sorter and make sure the results come out sorted.
+ (1 to 4096).foreach { i =>
+ sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
+
+ if ((i % 100) == 0) {
+ shuffleMemoryManager.markAsOutOfMemory()
+ sorter.closeCurrentPage()
+ }
+ }
+
+ var count = 0
+ val iter = sorter.sortedIterator()
+ while (iter.next()) {
+      // Here, we also check that copying the key and value works.
+      iter.getKey.copy()
+      iter.getValue.copy()
+      count += 1
+ }
+
+ // 1 record was from the map and 4096 records were explicitly inserted.
+ assert(count === 4097)
+
+ map.free()
+ }
}
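
Finally, a quick usage sketch of the KWayMergeSketch helper shown earlier, mirroring what these tests verify: mixing empty and non-empty inputs must not add phantom records to the merged output.

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.Iterator;
    import java.util.List;

    final class KWayMergeSketchDemo {
      public static void main(String[] args) {
        List<Iterator<Integer>> inputs = Arrays.asList(
            Arrays.asList(1, 4, 7).iterator(),
            Collections.<Integer>emptyIterator(),   // plays the role of an empty spill
            Arrays.asList(2, 3, 9).iterator());
        List<Integer> merged = KWayMergeSketch.merge(inputs);
        // Exactly six records, in order; the empty input contributes nothing.
        assert merged.equals(Arrays.asList(1, 2, 3, 4, 7, 9)) : merged;
      }
    }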