author     Reynold Xin <rxin@databricks.com>    2015-08-02 12:32:14 -0700
committer  Josh Rosen <joshrosen@databricks.com>    2015-08-02 12:32:14 -0700
commit     2e981b7bfa9dec93fdcf25f3e7220cd6aaba744f (patch)
tree       f7458ae297d36bba1acf21fd08169defef6c2ef8 /sql
parent     66924ffa6bdb8e0df1b90b789cb7ad443377e729 (diff)
[SPARK-9531] [SQL] UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter
This pull request adds a destructAndCreateExternalSorter method to UnsafeFixedWidthAggregationMap. The new method does the following:

1. Creates a new external sorter, UnsafeKVExternalSorter
2. Adds all the data into an in-memory sorter and sorts it
3. Spills the sorted in-memory data to disk

This method can be used to fall back to sort-based aggregation when under memory pressure. The pull request also includes memory accounting fixes from JoshRosen.

TODOs (that can be done in follow-up PRs):
- [x] Address Josh's feedback from #7849
- [x] More documentation and test cases
- [x] Make sure we are doing memory accounting correctly, with test cases (e.g. did we release the memory in BytesToBytesMap twice?)
- [ ] Look harder at possible memory leaks and exception handling
- [ ] Randomized tester for the KV sorter as well as the aggregation map

Author: Reynold Xin <rxin@databricks.com>
Author: Josh Rosen <joshrosen@databricks.com>

Closes #7860 from rxin/kvsorter and squashes the following commits:

986a58c [Reynold Xin] Bug fix.
599317c [Reynold Xin] Style fix and slightly more compact code.
fe7bd4e [Reynold Xin] Bug fixes.
fd71bef [Reynold Xin] Merge remote-tracking branch 'josh/large-records-in-sql-sorter' into kvsorter-with-josh-fix
3efae38 [Reynold Xin] More fixes and documentation.
45f1b09 [Josh Rosen] Ensure that spill files are cleaned up
f6a9bd3 [Reynold Xin] Josh feedback.
9be8139 [Reynold Xin] Remove testSpillFrequency.
7cbe759 [Reynold Xin] [SPARK-9531][SQL] UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter.
ae4a8af [Josh Rosen] Detect leaked unsafe memory in UnsafeExternalSorterSuite.
52f9b06 [Josh Rosen] Detect ShuffleMemoryManager leaks in UnsafeExternalSorter.
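For orientation, a minimal sketch (not part of the patch) of the fallback path this enables. Here emptyAggregationBuffer, the schemas, the memory managers, pageSizeBytes, unsafeKey/unsafeValue, and process are hypothetical placeholders; the constructor arguments mirror those in the test suite below:

    val map = new UnsafeFixedWidthAggregationMap(
      emptyAggregationBuffer,   // initial ("zero") aggregation buffer
      aggBufferSchema,          // schema of the value (aggregation buffer) rows
      groupKeySchema,           // schema of the grouping-key rows
      taskMemoryManager,
      shuffleMemoryManager,
      128,                      // initial capacity
      pageSizeBytes,
      false                     // disable perf metrics
    )

    // ... hash-based aggregation proceeds until memory pressure is detected ...

    // Steps 1-3 above: sort the map's records in memory, spill them to disk, and
    // hand back a sorter. The map is destroyed and must not be touched afterwards.
    val sorter = map.destructAndCreateExternalSorter()

    // Continue in sort-based mode: insert the remaining records, spilling as needed.
    sorter.insertKV(unsafeKey, unsafeValue)

    // Merge the spilled runs and consume key/value pairs in key-sorted order.
    val iter = sorter.sortedIterator()
    while (iter.next()) {
      process(iter.getKey, iter.getValue)
    }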
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java    3
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java    9
-rw-r--r--  sql/core/pom.xml    5
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java    103
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java    236
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala    33
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala    4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala    51
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala    124
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala    158
10 files changed, 586 insertions, 140 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 1b475b2492..b4fc0b7b70 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -507,7 +507,8 @@ public final class UnsafeRow extends MutableRow {
public String toString() {
StringBuilder build = new StringBuilder("[");
for (int i = 0; i < sizeInBytes; i += 8) {
- build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+ build.append(java.lang.Long.toHexString(
+ PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)));
build.append(',');
}
build.append(']');
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 68c49feae9..5e4c6232c9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -59,20 +59,21 @@ final class UnsafeExternalRowSorter {
StructType schema,
Ordering<InternalRow> ordering,
PrefixComparator prefixComparator,
- PrefixComputer prefixComputer) throws IOException {
+ PrefixComputer prefixComputer,
+ long pageSizeBytes) throws IOException {
this.schema = schema;
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
- sorter = new UnsafeExternalSorter(
+ sorter = UnsafeExternalSorter.create(
taskContext.taskMemoryManager(),
sparkEnv.shuffleMemoryManager(),
sparkEnv.blockManager(),
taskContext,
new RowComparator(ordering, schema.length()),
prefixComparator,
- 4096,
- sparkEnv.conf()
+ /* initialSize */ 4096,
+ pageSizeBytes
);
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index be0966641b..349007789f 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -106,6 +106,11 @@
<artifactId>parquet-avro</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index a0a8dd5154..9e2c9334a7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,24 +19,18 @@ package org.apache.spark.sql.execution;
import java.io.IOException;
+import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
-import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
-import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
-import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter;
-import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
/**
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -215,7 +209,7 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
- * Free the unsafe memory associated with this map.
+ * Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
public void free() {
map.free();
@@ -233,92 +227,17 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
- * Sorts the key, value data in this map in place, and return them as an iterator.
+ * Sorts the map's records in place, spills them to disk, and returns an [[UnsafeKVExternalSorter]]
+ * that can be used to insert more records and continue the external sort.
*
* The only memory that is allocated is the address/prefix array, 16 bytes per record.
+ *
+ * Note that this destroys the map, and as a result, the map cannot be used anymore after this.
*/
- public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() {
- int numElements = map.numElements();
- final int numKeyFields = groupingKeySchema.size();
- TaskMemoryManager memoryManager = map.getTaskMemoryManager();
-
- UnsafeExternalRowSorter.PrefixComputer prefixComp =
- SortPrefixUtils.createPrefixGenerator(groupingKeySchema);
- PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema);
-
- final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema);
- RecordComparator recordComparator = new RecordComparator() {
- private final UnsafeRow row1 = new UnsafeRow();
- private final UnsafeRow row2 = new UnsafeRow();
-
- @Override
- public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
- row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
- row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
- return ordering.compare(row1, row2);
- }
- };
-
- // Insert the records into the in-memory sorter.
- final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
- memoryManager, recordComparator, prefixComparator, numElements);
-
- BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
- UnsafeRow row = new UnsafeRow();
- while (iter.hasNext()) {
- final BytesToBytesMap.Location loc = iter.next();
- final Object baseObject = loc.getKeyAddress().getBaseObject();
- final long baseOffset = loc.getKeyAddress().getBaseOffset();
-
- // Get encoded memory address
- MemoryBlock page = loc.getMemoryPage();
- long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
-
- // Compute prefix
- row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
- final long prefix = prefixComp.computePrefix(row);
-
- sorter.insertRecord(address, prefix);
- }
-
- // Return the sorted result as an iterator.
- return new KVIterator<UnsafeRow, UnsafeRow>() {
-
- private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
- private final UnsafeRow key = new UnsafeRow();
- private final UnsafeRow value = new UnsafeRow();
- private int numValueFields = aggregationBufferSchema.size();
-
- @Override
- public boolean next() throws IOException {
- if (sortedIterator.hasNext()) {
- sortedIterator.loadNext();
- Object baseObj = sortedIterator.getBaseObject();
- long recordOffset = sortedIterator.getBaseOffset();
- int recordLen = sortedIterator.getRecordLength();
- int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
- key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
- value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen);
- return true;
- } else {
- return false;
- }
- }
-
- @Override
- public UnsafeRow getKey() {
- return key;
- }
-
- @Override
- public UnsafeRow getValue() {
- return value;
- }
-
- @Override
- public void close() {
- // Do nothing
- }
- };
+ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
+ UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
+ groupingKeySchema, aggregationBufferSchema,
+ SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map);
+ return sorter;
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
new file mode 100644
index 0000000000..f6b0176863
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import javax.annotation.Nullable;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.KVIterator;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.*;
+
+/**
+ * A class for performing external sorting on key-value records. Both key and value are UnsafeRows.
+ *
+ * Note that this class allows optionally passing in a {@link BytesToBytesMap} directly in order
+ * to perform in-place sorting of records in the map.
+ */
+public final class UnsafeKVExternalSorter {
+
+ private final StructType keySchema;
+ private final StructType valueSchema;
+ private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
+ private final UnsafeExternalSorter sorter;
+
+ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+ BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes)
+ throws IOException {
+ this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null);
+ }
+
+ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+ BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes,
+ @Nullable BytesToBytesMap map) throws IOException {
+ this.keySchema = keySchema;
+ this.valueSchema = valueSchema;
+ final TaskContext taskContext = TaskContext.get();
+
+ prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
+ PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
+ BaseOrdering ordering = GenerateOrdering.create(keySchema);
+ KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+
+ TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager();
+
+ if (map == null) {
+ sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ /* initialSize */ 4096,
+ pageSizeBytes);
+ } else {
+ // Insert the records into the in-memory sorter.
+ final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
+ taskMemoryManager, recordComparator, prefixComparator, map.numElements());
+
+ final int numKeyFields = keySchema.size();
+ BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+ UnsafeRow row = new UnsafeRow();
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ final Object baseObject = loc.getKeyAddress().getBaseObject();
+ final long baseOffset = loc.getKeyAddress().getBaseOffset();
+
+ // Get encoded memory address
+ // baseObject + baseOffset point to the beginning of the key data in the map, but
+ // the KV-pair's length data is stored in the word immediately before that address
+ MemoryBlock page = loc.getMemoryPage();
+ long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+
+ // Compute prefix
+ row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+ final long prefix = prefixComputer.computePrefix(row);
+
+ inMemSorter.insertRecord(address, prefix);
+ }
+
+ sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
+ taskContext.taskMemoryManager(),
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ new KVComparator(ordering, keySchema.length()),
+ prefixComparator,
+ /* initialSize */ 4096,
+ pageSizeBytes,
+ inMemSorter);
+
+ sorter.spill();
+ map.free();
+ }
+ }
+
+ /**
+ * Inserts a key-value record into the sorter. If the sorter no longer has enough memory to hold
+ * the record, the sorter sorts the existing records in memory, writes them out as partially
+ * sorted runs, and then reallocates memory to hold the new record.
+ */
+ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
+ final long prefix = prefixComputer.computePrefix(key);
+ sorter.insertKVRecord(
+ key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
+ value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
+ }
+
+ public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
+ try {
+ final UnsafeSorterIterator underlying = sorter.getSortedIterator();
+ if (!underlying.hasNext()) {
+ // Since we won't ever call next() on an empty iterator, we need to clean up resources
+ // here in order to prevent memory leaks.
+ cleanupResources();
+ }
+
+ return new KVIterator<UnsafeRow, UnsafeRow>() {
+ private UnsafeRow key = new UnsafeRow();
+ private UnsafeRow value = new UnsafeRow();
+ private int numKeyFields = keySchema.size();
+ private int numValueFields = valueSchema.size();
+
+ @Override
+ public boolean next() throws IOException {
+ try {
+ if (underlying.hasNext()) {
+ underlying.loadNext();
+
+ Object baseObj = underlying.getBaseObject();
+ long recordOffset = underlying.getBaseOffset();
+ int recordLen = underlying.getRecordLength();
+
+ // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+ int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
+ int valueLen = recordLen - keyLen - 4;
+
+ key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+ value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+
+ return true;
+ } else {
+ key = null;
+ value = null;
+ cleanupResources();
+ return false;
+ }
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+ @Override
+ public UnsafeRow getKey() {
+ return key;
+ }
+
+ @Override
+ public UnsafeRow getValue() {
+ return value;
+ }
+
+ @Override
+ public void close() {
+ cleanupResources();
+ }
+ };
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+ /**
+ * Marks the current page as no-more-space-available; as a result, the sorter either allocates
+ * a new page or spills when it sees the next record.
+ */
+ @VisibleForTesting
+ void closeCurrentPage() {
+ sorter.closeCurrentPage();
+ }
+
+ private void cleanupResources() {
+ sorter.freeMemory();
+ }
+
+ private static final class KVComparator extends RecordComparator {
+ private final BaseOrdering ordering;
+ private final UnsafeRow row1 = new UnsafeRow();
+ private final UnsafeRow row2 = new UnsafeRow();
+ private final int numKeyFields;
+
+ public KVComparator(BaseOrdering ordering, int numKeyFields) {
+ this.numKeyFields = numKeyFields;
+ this.ordering = ordering;
+ }
+
+ @Override
+ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+ // Note that since ordering doesn't need the total length of the record, we just pass -1
+ // into the row.
+ row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
+ row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+ return ordering.compare(row1, row2);
+ }
+ }
+}
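A worked note on the record layout that sortedIterator()'s next() decodes above; the 24/16-byte sizes are a hypothetical example:

    // One KV record in a sorter page: [ keyLen : 4 bytes ][ key ][ value ]
    // Hypothetical example: a 24-byte key row and a 16-byte value row.
    val keyLen    = 24
    val valueLen  = 16
    val recordLen = 4 + keyLen + valueLen        // 44 = what getRecordLength() reports
    assert(valueLen == recordLen - keyLen - 4)   // the derivation used in next()
    // key.pointTo(...)   reads from recordOffset + 4           (skips the length word)
    // value.pointTo(...) reads from recordOffset + 4 + keyLen  (recordOffset + 28 here)

In the BytesToBytesMap branch of the constructor, by contrast, loc.getKeyAddress() points directly at the key bytes while the pair's length data sits in the word just before them, which is why the encoded record address there is baseOffset - 8.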
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 2e870ec8ae..49adf21537 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -50,17 +50,36 @@ object SortPrefixUtils {
}
}
+ /**
+ * Creates the prefix comparator for the first field in the given schema, in ascending order.
+ */
def getPrefixComparator(schema: StructType): PrefixComparator = {
- val field = schema.head
- getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ if (schema.nonEmpty) {
+ val field = schema.head
+ getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ } else {
+ new PrefixComparator {
+ override def compare(prefix1: Long, prefix2: Long): Int = 0
+ }
+ }
}
+ /**
+ * Creates the prefix computer for the first field in the given schema, in ascending order.
+ */
def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
- val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
- val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending)))
- new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ if (schema.nonEmpty) {
+ val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
+ val prefixProjection = UnsafeProjection.create(
+ SortPrefix(SortOrder(boundReference, Ascending)))
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
+ }
+ }
+ } else {
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = 0
}
}
}
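A minimal sketch of why the new empty-schema branches matter: an aggregation with no grouping columns (such as a global aggregate) yields a zero-field key schema, and before this change schema.head would throw on it. The assertions below are illustrative, not from the patch:

    val emptyKeySchema = new StructType()   // no grouping columns
    val prefixCmp = SortPrefixUtils.getPrefixComparator(emptyKeySchema)
    val prefixGen = SortPrefixUtils.createPrefixGenerator(emptyKeySchema)

    // Every prefix compares equal, so ordering falls through to the record comparator.
    assert(prefixCmp.compare(123L, 456L) == 0)
    assert(prefixGen.computePrefix(InternalRow()) == 0L)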
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 6d903ab23c..92cf328c76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -116,6 +116,7 @@ case class TungstenSort(
protected override def doExecute(): RDD[InternalRow] = {
val schema = child.schema
val childOutput = child.output
+ val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
child.execute().mapPartitions({ iter =>
val ordering = newOrdering(sortOrder, childOutput)
@@ -131,7 +132,8 @@ case class TungstenSort(
}
}
- val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+ val sorter = new UnsafeExternalRowSorter(
+ schema, ordering, prefixComparator, prefixComputer, pageSize)
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
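The page size handed to the sorter above now comes from configuration rather than a hardwired constant; a sketch of the lookup (getSizeAsBytes parses size strings such as "64m"):

    import org.apache.spark.SparkConf

    val conf = new SparkConf().set("spark.buffer.pageSize", "64m")
    val pageSize = conf.getSizeAsBytes("spark.buffer.pageSize", "64m")  // 67108864 bytes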
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
new file mode 100644
index 0000000000..53de2d0f07
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.shuffle.ShuffleMemoryManager
+
+/**
+ * A [[ShuffleMemoryManager]] that can be controlled to run out of memory.
+ */
+class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) {
+ private var oom = false
+
+ override def tryToAcquire(numBytes: Long): Long = {
+ if (oom) {
+ oom = false
+ 0
+ } else {
+ // Uncomment the following to trace memory allocations.
+ // println(s"tryToAcquire $numBytes in " +
+ // Thread.currentThread().getStackTrace.mkString("", "\n -", ""))
+ val acquired = super.tryToAcquire(numBytes)
+ acquired
+ }
+ }
+
+ override def release(numBytes: Long): Unit = {
+ // Uncomment the following to trace memory releases.
+ // println(s"release $numBytes in " +
+ // Thread.currentThread().getStackTrace.mkString("", "\n -", ""))
+ super.release(numBytes)
+ }
+
+ def markAsOutOfMemory(): Unit = {
+ oom = true
+ }
+}
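The suites below pair this with UnsafeKVExternalSorter.closeCurrentPage() to force spills deterministically; restating their pattern (sorter, key, and value are hypothetical here):

    shuffleMemMgr.markAsOutOfMemory()   // the next tryToAcquire(...) returns 0 bytes
    sorter.closeCurrentPage()           // the next insert cannot reuse the current page...
    sorter.insertKV(key, value)         // ...so it must allocate, is granted nothing, and spills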
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 098bdd0017..4c94b3307d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -17,24 +17,26 @@
package org.apache.spark.sql.execution
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
-import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
import scala.collection.mutable
-import scala.util.Random
+import scala.util.{Try, Random}
+
+import org.scalatest.Matchers
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
-
-class UnsafeFixedWidthAggregationMapSuite
- extends SparkFunSuite
- with Matchers
- with BeforeAndAfterEach {
+/**
+ * Test suite for [[UnsafeFixedWidthAggregationMap]].
+ *
+ * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
+ */
+class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
import UnsafeFixedWidthAggregationMap._
@@ -44,23 +46,40 @@ class UnsafeFixedWidthAggregationMapSuite
private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
private var taskMemoryManager: TaskMemoryManager = null
- private var shuffleMemoryManager: ShuffleMemoryManager = null
+ private var shuffleMemoryManager: TestShuffleMemoryManager = null
+
+ def testWithMemoryLeakDetection(name: String)(f: => Unit) {
+ def cleanup(): Unit = {
+ if (taskMemoryManager != null) {
+ val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+ assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
+ assert(leakedShuffleMemory === 0)
+ taskMemoryManager = null
+ }
+ }
- override def beforeEach(): Unit = {
- taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
+ test(name) {
+ taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ shuffleMemoryManager = new TestShuffleMemoryManager
+ try {
+ f
+ } catch {
+ case NonFatal(e) =>
+ Try(cleanup())
+ throw e
+ }
+ cleanup()
+ }
}
- override def afterEach(): Unit = {
- if (taskMemoryManager != null) {
- val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
- assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
- assert(leakedShuffleMemory === 0)
- taskMemoryManager = null
- }
+ private def randomStrings(n: Int): Seq[String] = {
+ val rand = new Random(42)
+ Seq.fill(n) {
+ Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+ }.distinct
}
- test("supported schemas") {
+ testWithMemoryLeakDetection("supported schemas") {
assert(supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
assert(!supportsAggregationBufferSchema(
@@ -70,7 +89,7 @@ class UnsafeFixedWidthAggregationMapSuite
!supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
}
- test("empty map") {
+ testWithMemoryLeakDetection("empty map") {
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -85,7 +104,7 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}
- test("updating values for a single key") {
+ testWithMemoryLeakDetection("updating values for a single key") {
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -113,7 +132,7 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}
- test("inserting large random keys") {
+ testWithMemoryLeakDetection("inserting large random keys") {
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -140,7 +159,21 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}
- test("test sorting") {
+ testWithMemoryLeakDetection("test external sorting") {
+ // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ metricsSystem = null))
+
+ // Memory consumption at the beginning of the task.
+ val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -152,26 +185,47 @@ class UnsafeFixedWidthAggregationMapSuite
false // disable perf metrics
)
- val rand = new Random(42)
- val groupKeys: Set[String] = Seq.fill(512) {
- Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
- }.toSet
- groupKeys.foreach { keyString =>
+ val keys = randomStrings(1024).take(512)
+ keys.foreach { keyString =>
val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
buf.setInt(0, keyString.length)
assert(buf != null)
}
+ // Convert the map into a sorter
+ val sorter = map.destructAndCreateExternalSorter()
+
+ withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
+ // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter.
+ assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() ===
+ initialMemoryConsumption + 4096 * 16)
+ }
+
+ // Add more keys to the sorter and make sure the results come out sorted.
+ val additionalKeys = randomStrings(1024)
+ val keyConverter = UnsafeProjection.create(groupKeySchema)
+ val valueConverter = UnsafeProjection.create(aggBufferSchema)
+
+ additionalKeys.zipWithIndex.foreach { case (str, i) =>
+ val k = InternalRow(UTF8String.fromString(str))
+ val v = InternalRow(str.length)
+ sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
+ if ((i % 100) == 0) {
+ shuffleMemoryManager.markAsOutOfMemory()
+ sorter.closeCurrentPage()
+ }
+ }
+
val out = new scala.collection.mutable.ArrayBuffer[String]
- val iter = map.sortedIterator()
+ val iter = sorter.sortedIterator()
while (iter.next()) {
assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
out += iter.getKey.getString(0)
}
- assert(out === groupKeys.toSeq.sorted)
+ assert(out === (keys ++ additionalKeys).sorted)
map.free()
}
-
}
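The withClue assertion above expects exactly initialMemoryConsumption + 4096 * 16 bytes to remain after the destruct; spelled out (the 16 bytes/record figure matches the "address/prefix array, 16 bytes per record" note in destructAndCreateExternalSorter's doc):

    val bytesPerRecord   = 8 + 8                   // encoded pointer + sort prefix = 16
    val expectedOverhead = 4096 * bytesPerRecord   // initial array capacity -> 65536 bytes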
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
new file mode 100644
index 0000000000..5d214d7bfc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection}
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark._
+
+class UnsafeKVExternalSorterSuite extends SparkFunSuite {
+
+ test("sorting string key and int int value") {
+
+ // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
+ val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val shuffleMemMgr = new TestShuffleMemoryManager
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemMgr,
+ metricsSystem = null))
+
+ val keySchema = new StructType().add("a", StringType)
+ val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType)
+ val sorter = new UnsafeKVExternalSorter(
+ keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
+ 16 * 1024)
+
+ val keyConverter = UnsafeProjection.create(keySchema)
+ val valueConverter = UnsafeProjection.create(valueSchema)
+
+ val rand = new Random(42)
+ val data = null +: Seq.fill[String](10) {
+ Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+ }
+
+ val inputRows = data.map { str =>
+ keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy()
+ }
+
+ var i = 0
+ data.foreach { str =>
+ if (str != null) {
+ val k = InternalRow(UTF8String.fromString(str))
+ val v = InternalRow(str.length, str.length + 1)
+ sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+ } else {
+ val k = InternalRow(UTF8String.fromString(str))
+ val v = InternalRow(-1, -2)
+ sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+ }
+
+ if ((i % 100) == 0) {
+ shuffleMemMgr.markAsOutOfMemory()
+ sorter.closeCurrentPage()
+ }
+ i += 1
+ }
+
+ val out = new scala.collection.mutable.ArrayBuffer[InternalRow]
+ val iter = sorter.sortedIterator()
+ while (iter.next()) {
+ if (iter.getKey.getUTF8String(0) == null) {
+ withClue(s"for null key") {
+ assert(-1 === iter.getValue.getInt(0))
+ assert(-2 === iter.getValue.getInt(1))
+ }
+ } else {
+ val key = iter.getKey.getString(0)
+ withClue(s"for key $key") {
+ assert(key.length === iter.getValue.getInt(0))
+ assert(key.length + 1 === iter.getValue.getInt(1))
+ }
+ }
+ out += iter.getKey.copy()
+ }
+
+ assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType))))
+ }
+
+ test("sorting arbitrary string data") {
+
+ // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
+ val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val shuffleMemMgr = new TestShuffleMemoryManager
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemMgr,
+ metricsSystem = null))
+
+ val keySchema = new StructType().add("a", StringType)
+ val valueSchema = new StructType().add("b", IntegerType)
+ val sorter = new UnsafeKVExternalSorter(
+ keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
+ 16 * 1024)
+
+ val keyConverter = UnsafeProjection.create(keySchema)
+ val valueConverter = UnsafeProjection.create(valueSchema)
+
+ val rand = new Random(42)
+ val data = Seq.fill(512) {
+ Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+ }
+
+ var i = 0
+ data.foreach { str =>
+ val k = InternalRow(UTF8String.fromString(str))
+ val v = InternalRow(str.length)
+ sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
+ if ((i % 100) == 0) {
+ shuffleMemMgr.markAsOutOfMemory()
+ sorter.closeCurrentPage()
+ }
+ i += 1
+ }
+
+ val out = new scala.collection.mutable.ArrayBuffer[String]
+ val iter = sorter.sortedIterator()
+ while (iter.next()) {
+ assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
+ out += iter.getKey.getString(0)
+ }
+
+ assert(out === data.sorted)
+ }
+}