diff options
Diffstat (limited to 'sql')
10 files changed, 586 insertions, 140 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 1b475b2492..b4fc0b7b70 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -507,7 +507,8 @@ public final class UnsafeRow extends MutableRow { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { - build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); + build.append(java.lang.Long.toHexString( + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i))); build.append(','); } build.append(']'); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 68c49feae9..5e4c6232c9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -59,20 +59,21 @@ final class UnsafeExternalRowSorter { StructType schema, Ordering<InternalRow> ordering, PrefixComparator prefixComparator, - PrefixComputer prefixComputer) throws IOException { + PrefixComputer prefixComputer, + long pageSizeBytes) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); - sorter = new UnsafeExternalSorter( + sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, new RowComparator(ordering, schema.length()), prefixComparator, - 4096, - sparkEnv.conf() + /* initialSize */ 4096, + pageSizeBytes ); } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index be0966641b..349007789f 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -106,6 +106,11 @@ <artifactId>parquet-avro</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index a0a8dd5154..9e2c9334a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,24 +19,18 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -215,7 +209,7 @@ public final class UnsafeFixedWidthAggregationMap { } /** - * Free the unsafe memory associated with this map. + * Free the memory associated with this map. This is idempotent and can be called multiple times. */ public void free() { map.free(); @@ -233,92 +227,17 @@ public final class UnsafeFixedWidthAggregationMap { } /** - * Sorts the key, value data in this map in place, and return them as an iterator. + * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]] + * that can be used to insert more records to do external sorting. * * The only memory that is allocated is the address/prefix array, 16 bytes per record. + * + * Note that this destroys the map, and as a result, the map cannot be used anymore after this. */ - public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() { - int numElements = map.numElements(); - final int numKeyFields = groupingKeySchema.size(); - TaskMemoryManager memoryManager = map.getTaskMemoryManager(); - - UnsafeExternalRowSorter.PrefixComputer prefixComp = - SortPrefixUtils.createPrefixGenerator(groupingKeySchema); - PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema); - - final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema); - RecordComparator recordComparator = new RecordComparator() { - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); - - @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); - row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); - return ordering.compare(row1, row2); - } - }; - - // Insert the records into the in-memory sorter. - final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - memoryManager, recordComparator, prefixComparator, numElements); - - BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); - UnsafeRow row = new UnsafeRow(); - while (iter.hasNext()) { - final BytesToBytesMap.Location loc = iter.next(); - final Object baseObject = loc.getKeyAddress().getBaseObject(); - final long baseOffset = loc.getKeyAddress().getBaseOffset(); - - // Get encoded memory address - MemoryBlock page = loc.getMemoryPage(); - long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8); - - // Compute prefix - row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); - final long prefix = prefixComp.computePrefix(row); - - sorter.insertRecord(address, prefix); - } - - // Return the sorted result as an iterator. - return new KVIterator<UnsafeRow, UnsafeRow>() { - - private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); - private final UnsafeRow key = new UnsafeRow(); - private final UnsafeRow value = new UnsafeRow(); - private int numValueFields = aggregationBufferSchema.size(); - - @Override - public boolean next() throws IOException { - if (sortedIterator.hasNext()) { - sortedIterator.loadNext(); - Object baseObj = sortedIterator.getBaseObject(); - long recordOffset = sortedIterator.getBaseOffset(); - int recordLen = sortedIterator.getRecordLength(); - int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); - value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen); - return true; - } else { - return false; - } - } - - @Override - public UnsafeRow getKey() { - return key; - } - - @Override - public UnsafeRow getValue() { - return value; - } - - @Override - public void close() { - // Do nothing - } - }; + public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { + UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter( + groupingKeySchema, aggregationBufferSchema, + SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map); + return sorter; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java new file mode 100644 index 0000000000..f6b0176863 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; + +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.TaskContext; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.KVIterator; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.collection.unsafe.sort.*; + +/** + * A class for performing external sorting on key-value records. Both key and value are UnsafeRows. + * + * Note that this class allows optionally passing in a {@link BytesToBytesMap} directly in order + * to perform in-place sorting of records in the map. + */ +public final class UnsafeKVExternalSorter { + + private final StructType keySchema; + private final StructType valueSchema; + private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; + private final UnsafeExternalSorter sorter; + + public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, + BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes) + throws IOException { + this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null); + } + + public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, + BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes, + @Nullable BytesToBytesMap map) throws IOException { + this.keySchema = keySchema; + this.valueSchema = valueSchema; + final TaskContext taskContext = TaskContext.get(); + + prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema); + PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema); + BaseOrdering ordering = GenerateOrdering.create(keySchema); + KVComparator recordComparator = new KVComparator(ordering, keySchema.length()); + + TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager(); + + if (map == null) { + sorter = UnsafeExternalSorter.create( + taskMemoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + /* initialSize */ 4096, + pageSizeBytes); + } else { + // Insert the records into the in-memory sorter. + final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( + taskMemoryManager, recordComparator, prefixComparator, map.numElements()); + + final int numKeyFields = keySchema.size(); + BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); + UnsafeRow row = new UnsafeRow(); + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + final Object baseObject = loc.getKeyAddress().getBaseObject(); + final long baseOffset = loc.getKeyAddress().getBaseOffset(); + + // Get encoded memory address + // baseObject + baseOffset point to the beginning of the key data in the map, but that + // the KV-pair's length data is stored in the word immediately before that address + MemoryBlock page = loc.getMemoryPage(); + long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8); + + // Compute prefix + row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); + final long prefix = prefixComputer.computePrefix(row); + + inMemSorter.insertRecord(address, prefix); + } + + sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( + taskContext.taskMemoryManager(), + shuffleMemoryManager, + blockManager, + taskContext, + new KVComparator(ordering, keySchema.length()), + prefixComparator, + /* initialSize */ 4096, + pageSizeBytes, + inMemSorter); + + sorter.spill(); + map.free(); + } + } + + /** + * Inserts a key-value record into the sorter. If the sorter no longer has enough memory to hold + * the record, the sorter sorts the existing records in-memory, writes them out as partially + * sorted runs, and then reallocates memory to hold the new record. + */ + public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException { + final long prefix = prefixComputer.computePrefix(key); + sorter.insertKVRecord( + key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(), + value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix); + } + + public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException { + try { + final UnsafeSorterIterator underlying = sorter.getSortedIterator(); + if (!underlying.hasNext()) { + // Since we won't ever call next() on an empty iterator, we need to clean up resources + // here in order to prevent memory leaks. + cleanupResources(); + } + + return new KVIterator<UnsafeRow, UnsafeRow>() { + private UnsafeRow key = new UnsafeRow(); + private UnsafeRow value = new UnsafeRow(); + private int numKeyFields = keySchema.size(); + private int numValueFields = valueSchema.size(); + + @Override + public boolean next() throws IOException { + try { + if (underlying.hasNext()) { + underlying.loadNext(); + + Object baseObj = underlying.getBaseObject(); + long recordOffset = underlying.getBaseOffset(); + int recordLen = underlying.getRecordLength(); + + // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) + int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); + int valueLen = recordLen - keyLen - 4; + + key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); + value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); + + return true; + } else { + key = null; + value = null; + cleanupResources(); + return false; + } + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + @Override + public UnsafeRow getKey() { + return key; + } + + @Override + public UnsafeRow getValue() { + return value; + } + + @Override + public void close() { + cleanupResources(); + } + }; + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + /** + * Marks the current page as no-more-space-available, and as a result, either allocate a + * new page or spill when we see the next record. + */ + @VisibleForTesting + void closeCurrentPage() { + sorter.closeCurrentPage(); + } + + private void cleanupResources() { + sorter.freeMemory(); + } + + private static final class KVComparator extends RecordComparator { + private final BaseOrdering ordering; + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + private final int numKeyFields; + + public KVComparator(BaseOrdering ordering, int numKeyFields) { + this.numKeyFields = numKeyFields; + this.ordering = ordering; + } + + @Override + public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + // Note that since ordering doesn't need the total length of the record, we just pass -1 + // into the row. + row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); + row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); + return ordering.compare(row1, row2); + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2e870ec8ae..49adf21537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -50,17 +50,36 @@ object SortPrefixUtils { } } + /** + * Creates the prefix comparator for the first field in the given schema, in ascending order. + */ def getPrefixComparator(schema: StructType): PrefixComparator = { - val field = schema.head - getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending)) + if (schema.nonEmpty) { + val field = schema.head + getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending)) + } else { + new PrefixComparator { + override def compare(prefix1: Long, prefix2: Long): Int = 0 + } + } } + /** + * Creates the prefix computer for the first field in the given schema, in ascending order. + */ def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = { - val boundReference = BoundReference(0, schema.head.dataType, nullable = true) - val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending))) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) + if (schema.nonEmpty) { + val boundReference = BoundReference(0, schema.head.dataType, nullable = true) + val prefixProjection = UnsafeProjection.create( + SortPrefix(SortOrder(boundReference, Ascending))) + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + } else { + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = 0 } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 6d903ab23c..92cf328c76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -116,6 +116,7 @@ case class TungstenSort( protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output + val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") child.execute().mapPartitions({ iter => val ordering = newOrdering(sortOrder, childOutput) @@ -131,7 +132,8 @@ case class TungstenSort( } } - val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) + val sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala new file mode 100644 index 0000000000..53de2d0f07 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.shuffle.ShuffleMemoryManager + +/** + * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. + */ +class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) { + private var oom = false + + override def tryToAcquire(numBytes: Long): Long = { + if (oom) { + oom = false + 0 + } else { + // Uncomment the following to trace memory allocations. + // println(s"tryToAcquire $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + val acquired = super.tryToAcquire(numBytes) + acquired + } + } + + override def release(numBytes: Long): Unit = { + // Uncomment the following to trace memory releases. + // println(s"release $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + super.release(numBytes) + } + + def markAsOutOfMemory(): Unit = { + oom = true + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 098bdd0017..4c94b3307d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -17,24 +17,26 @@ package org.apache.spark.sql.execution -import org.scalatest.{BeforeAndAfterEach, Matchers} - -import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import scala.collection.mutable -import scala.util.Random +import scala.util.{Try, Random} + +import org.scalatest.Matchers -import org.apache.spark.SparkFunSuite -import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String - -class UnsafeFixedWidthAggregationMapSuite - extends SparkFunSuite - with Matchers - with BeforeAndAfterEach { +/** + * Test suite for [[UnsafeFixedWidthAggregationMap]]. + * + * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. + */ +class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { import UnsafeFixedWidthAggregationMap._ @@ -44,23 +46,40 @@ class UnsafeFixedWidthAggregationMapSuite private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes private var taskMemoryManager: TaskMemoryManager = null - private var shuffleMemoryManager: ShuffleMemoryManager = null + private var shuffleMemoryManager: TestShuffleMemoryManager = null + + def testWithMemoryLeakDetection(name: String)(f: => Unit) { + def cleanup(): Unit = { + if (taskMemoryManager != null) { + val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() + assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) + assert(leakedShuffleMemory === 0) + taskMemoryManager = null + } + } - override def beforeEach(): Unit = { - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue) + test(name) { + taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + shuffleMemoryManager = new TestShuffleMemoryManager + try { + f + } catch { + case NonFatal(e) => + Try(cleanup()) + throw e + } + cleanup() + } } - override def afterEach(): Unit = { - if (taskMemoryManager != null) { - val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() - assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) - assert(leakedShuffleMemory === 0) - taskMemoryManager = null - } + private def randomStrings(n: Int): Seq[String] = { + val rand = new Random(42) + Seq.fill(512) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + }.distinct } - test("supported schemas") { + testWithMemoryLeakDetection("supported schemas") { assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema( @@ -70,7 +89,7 @@ class UnsafeFixedWidthAggregationMapSuite !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) } - test("empty map") { + testWithMemoryLeakDetection("empty map") { val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -85,7 +104,7 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } - test("updating values for a single key") { + testWithMemoryLeakDetection("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -113,7 +132,7 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } - test("inserting large random keys") { + testWithMemoryLeakDetection("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -140,7 +159,21 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } - test("test sorting") { + testWithMemoryLeakDetection("test external sorting") { + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = null)) + + // Memory consumption in the beginning of the task. + val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -152,26 +185,47 @@ class UnsafeFixedWidthAggregationMapSuite false // disable perf metrics ) - val rand = new Random(42) - val groupKeys: Set[String] = Seq.fill(512) { - Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString - }.toSet - groupKeys.foreach { keyString => + val keys = randomStrings(1024).take(512) + keys.foreach { keyString => val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) buf.setInt(0, keyString.length) assert(buf != null) } + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + val out = new scala.collection.mutable.ArrayBuffer[String] - val iter = map.sortedIterator() + val iter = sorter.sortedIterator() while (iter.next()) { assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) out += iter.getKey.getString(0) } - assert(out === groupKeys.toSeq.sorted) + assert(out === (keys ++ additionalKeys).sorted) map.free() } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala new file mode 100644 index 0000000000..5d214d7bfc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark._ + +class UnsafeKVExternalSorterSuite extends SparkFunSuite { + + test("sorting string key and int int value") { + + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val shuffleMemMgr = new TestShuffleMemoryManager + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null)) + + val keySchema = new StructType().add("a", StringType) + val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType) + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, + 16 * 1024) + + val keyConverter = UnsafeProjection.create(keySchema) + val valueConverter = UnsafeProjection.create(valueSchema) + + val rand = new Random(42) + val data = null +: Seq.fill[String](10) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + } + + val inputRows = data.map { str => + keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy() + } + + var i = 0 + data.foreach { str => + if (str != null) { + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length, str.length + 1) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + } else { + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(-1, -2) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + } + + if ((i % 100) == 0) { + shuffleMemMgr.markAsOutOfMemory() + sorter.closeCurrentPage() + } + i += 1 + } + + val out = new scala.collection.mutable.ArrayBuffer[InternalRow] + val iter = sorter.sortedIterator() + while (iter.next()) { + if (iter.getKey.getUTF8String(0) == null) { + withClue(s"for null key") { + assert(-1 === iter.getValue.getInt(0)) + assert(-2 === iter.getValue.getInt(1)) + } + } else { + val key = iter.getKey.getString(0) + withClue(s"for key $key") { + assert(key.length === iter.getValue.getInt(0)) + assert(key.length + 1 === iter.getValue.getInt(1)) + } + } + out += iter.getKey.copy() + } + + assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType)))) + } + + test("sorting arbitrary string data") { + + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val shuffleMemMgr = new TestShuffleMemoryManager + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null)) + + val keySchema = new StructType().add("a", StringType) + val valueSchema = new StructType().add("b", IntegerType) + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, + 16 * 1024) + + val keyConverter = UnsafeProjection.create(keySchema) + val valueConverter = UnsafeProjection.create(valueSchema) + + val rand = new Random(42) + val data = Seq.fill(512) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + } + + var i = 0 + data.foreach { str => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemMgr.markAsOutOfMemory() + sorter.closeCurrentPage() + } + i += 1 + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) + out += iter.getKey.getString(0) + } + + assert(out === data.sorted) + } +} |