diff options
7 files changed, 196 insertions, 15 deletions
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 481375f493..cf222b7272 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -23,6 +23,8 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import javax.annotation.Nullable; + import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -217,6 +219,7 @@ public final class BytesToBytesMap { private final Iterator<MemoryBlock> dataPagesIterator; private final Location loc; + private MemoryBlock currentPage; private int currentRecordNumber = 0; private Object pageBaseObject; private long offsetInPage; @@ -232,7 +235,7 @@ public final class BytesToBytesMap { } private void advanceToNextPage() { - final MemoryBlock currentPage = dataPagesIterator.next(); + currentPage = dataPagesIterator.next(); pageBaseObject = currentPage.getBaseObject(); offsetInPage = currentPage.getBaseOffset(); } @@ -249,7 +252,7 @@ public final class BytesToBytesMap { advanceToNextPage(); totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); } - loc.with(pageBaseObject, offsetInPage); + loc.with(currentPage, offsetInPage); offsetInPage += 8 + totalLength; currentRecordNumber++; return loc; @@ -346,14 +349,19 @@ public final class BytesToBytesMap { private int keyLength; private int valueLength; + /** + * Memory page containing the record. Only set if created by {@link BytesToBytesMap#iterator()}. + */ + @Nullable private MemoryBlock memoryPage; + private void updateAddressesAndSizes(long fullKeyAddress) { updateAddressesAndSizes( taskMemoryManager.getPage(fullKeyAddress), taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) { - long position = keyOffsetInPage; + private void updateAddressesAndSizes(final Object page, final long offsetInPage) { + long position = offsetInPage; final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); position += 4; keyLength = PlatformDependent.UNSAFE.getInt(page, position); @@ -366,7 +374,7 @@ public final class BytesToBytesMap { valueMemoryLocation.setObjAndOffset(page, position); } - Location with(int pos, int keyHashcode, boolean isDefined) { + private Location with(int pos, int keyHashcode, boolean isDefined) { this.pos = pos; this.isDefined = isDefined; this.keyHashcode = keyHashcode; @@ -377,13 +385,22 @@ public final class BytesToBytesMap { return this; } - Location with(Object page, long keyOffsetInPage) { + private Location with(MemoryBlock page, long offsetInPage) { this.isDefined = true; - updateAddressesAndSizes(page, keyOffsetInPage); + this.memoryPage = page; + updateAddressesAndSizes(page.getBaseObject(), offsetInPage); return this; } /** + * Returns the memory page that contains the current record. + * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. + */ + public MemoryBlock getMemoryPage() { + return this.memoryPage; + } + + /** * Returns true if the key is defined at this position, and false otherwise. */ public boolean isDefined() { @@ -538,7 +555,7 @@ public final class BytesToBytesMap { long insertCursor = dataPageInsertOffset; // Compute all of our offsets up-front: - final long totalLengthOffset = insertCursor; + final long recordOffset = insertCursor; insertCursor += 4; final long keyLengthOffset = insertCursor; insertCursor += 4; @@ -547,7 +564,7 @@ public final class BytesToBytesMap { final long valueDataOffsetInPage = insertCursor; insertCursor += valueLengthBytes; // word used to store the value size - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset, + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset, keyLengthBytes + valueLengthBytes); PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key @@ -569,7 +586,7 @@ public final class BytesToBytesMap { numElements++; bitset.set(pos); final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( - dataPage, totalLengthOffset); + dataPage, recordOffset); longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); @@ -618,6 +635,10 @@ public final class BytesToBytesMap { assert(dataPages.isEmpty()); } + public TaskMemoryManager getTaskMemoryManager() { + return taskMemoryManager; + } + /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ public long getTotalMemoryConsumption() { long totalDataPagesSize = 0L; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index d79325aea8..000be70f17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -125,6 +125,8 @@ object UnsafeProjection { GenerateUnsafeProjection.generate(exprs) } + def create(expr: Expression): UnsafeProjection = create(Seq(expr)) + /** * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to * `inputSchema`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index dbd4616d28..cc848aa199 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -21,6 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Private import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructType /** * Inherits some default implementation for Java from `Ordering[Row]` @@ -43,7 +44,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = in.map(BindReferences.bindReference(_, inputSchema)) - protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { + /** + * Creates a code gen ordering for sorting this schema, in ascending order. + */ + def create(schema: StructType): BaseOrdering = { + create(schema.zipWithIndex.map { case (field, ordinal) => + SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending) + }) + } + + protected def create(ordering: Seq[SortOrder]): BaseOrdering = { val ctx = newCodeGenContext() val comparisons = ordering.map { order => diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c18b6dea6b..a0a8dd5154 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -17,19 +17,26 @@ package org.apache.spark.sql.execution; +import java.io.IOException; + import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; +import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -225,4 +232,93 @@ public final class UnsafeFixedWidthAggregationMap { System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); } + /** + * Sorts the key, value data in this map in place, and return them as an iterator. + * + * The only memory that is allocated is the address/prefix array, 16 bytes per record. + */ + public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() { + int numElements = map.numElements(); + final int numKeyFields = groupingKeySchema.size(); + TaskMemoryManager memoryManager = map.getTaskMemoryManager(); + + UnsafeExternalRowSorter.PrefixComputer prefixComp = + SortPrefixUtils.createPrefixGenerator(groupingKeySchema); + PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema); + + final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema); + RecordComparator recordComparator = new RecordComparator() { + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + + @Override + public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); + row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); + return ordering.compare(row1, row2); + } + }; + + // Insert the records into the in-memory sorter. + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( + memoryManager, recordComparator, prefixComparator, numElements); + + BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); + UnsafeRow row = new UnsafeRow(); + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + final Object baseObject = loc.getKeyAddress().getBaseObject(); + final long baseOffset = loc.getKeyAddress().getBaseOffset(); + + // Get encoded memory address + MemoryBlock page = loc.getMemoryPage(); + long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8); + + // Compute prefix + row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); + final long prefix = prefixComp.computePrefix(row); + + sorter.insertRecord(address, prefix); + } + + // Return the sorted result as an iterator. + return new KVIterator<UnsafeRow, UnsafeRow>() { + + private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); + private final UnsafeRow key = new UnsafeRow(); + private final UnsafeRow value = new UnsafeRow(); + private int numValueFields = aggregationBufferSchema.size(); + + @Override + public boolean next() throws IOException { + if (sortedIterator.hasNext()) { + sortedIterator.loadNext(); + Object baseObj = sortedIterator.getBaseObject(); + long recordOffset = sortedIterator.getBaseOffset(); + int recordLen = sortedIterator.getRecordLength(); + int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); + key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); + value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen); + return true; + } else { + return false; + } + } + + @Override + public UnsafeRow getKey() { + return key; + } + + @Override + public UnsafeRow getValue() { + return value; + } + + @Override + public void close() { + // Do nothing + } + }; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index a2145b185c..17d4166af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -46,4 +47,19 @@ object SortPrefixUtils { case _ => NoOpPrefixComparator } } + + def getPrefixComparator(schema: StructType): PrefixComparator = { + val field = schema.head + getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending)) + } + + def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = { + val boundReference = BoundReference(0, schema.head.dataType, nullable = true) + val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending))) + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 6a2c51ca88..098bdd0017 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -140,4 +140,38 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } + test("test sorting") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + }.toSet + groupKeys.foreach { keyString => + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) + buf.setInt(0, keyString.length) + assert(buf != null) + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = map.sortedIterator() + while (iter.next()) { + assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) + out += iter.getKey.getString(0) + } + + assert(out === groupKeys.toSeq.sorted) + + map.free() + } + } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java index fb163401c0..5c9d5d9a38 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java @@ -17,9 +17,11 @@ package org.apache.spark.unsafe; +import java.io.IOException; + public abstract class KVIterator<K, V> { - public abstract boolean next(); + public abstract boolean next() throws IOException; public abstract K getKey(); |