aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java41
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala12
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java100
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala18
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala34
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java4
7 files changed, 196 insertions, 15 deletions
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 481375f493..cf222b7272 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -23,6 +23,8 @@ import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
+import javax.annotation.Nullable;
+
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -217,6 +219,7 @@ public final class BytesToBytesMap {
private final Iterator<MemoryBlock> dataPagesIterator;
private final Location loc;
+ private MemoryBlock currentPage;
private int currentRecordNumber = 0;
private Object pageBaseObject;
private long offsetInPage;
@@ -232,7 +235,7 @@ public final class BytesToBytesMap {
}
private void advanceToNextPage() {
- final MemoryBlock currentPage = dataPagesIterator.next();
+ currentPage = dataPagesIterator.next();
pageBaseObject = currentPage.getBaseObject();
offsetInPage = currentPage.getBaseOffset();
}
@@ -249,7 +252,7 @@ public final class BytesToBytesMap {
advanceToNextPage();
totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
}
- loc.with(pageBaseObject, offsetInPage);
+ loc.with(currentPage, offsetInPage);
offsetInPage += 8 + totalLength;
currentRecordNumber++;
return loc;
@@ -346,14 +349,19 @@ public final class BytesToBytesMap {
private int keyLength;
private int valueLength;
+ /**
+ * Memory page containing the record. Only set if created by {@link BytesToBytesMap#iterator()}.
+ */
+ @Nullable private MemoryBlock memoryPage;
+
private void updateAddressesAndSizes(long fullKeyAddress) {
updateAddressesAndSizes(
taskMemoryManager.getPage(fullKeyAddress),
taskMemoryManager.getOffsetInPage(fullKeyAddress));
}
- private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) {
- long position = keyOffsetInPage;
+ private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
+ long position = offsetInPage;
final int totalLength = PlatformDependent.UNSAFE.getInt(page, position);
position += 4;
keyLength = PlatformDependent.UNSAFE.getInt(page, position);
@@ -366,7 +374,7 @@ public final class BytesToBytesMap {
valueMemoryLocation.setObjAndOffset(page, position);
}
- Location with(int pos, int keyHashcode, boolean isDefined) {
+ private Location with(int pos, int keyHashcode, boolean isDefined) {
this.pos = pos;
this.isDefined = isDefined;
this.keyHashcode = keyHashcode;
@@ -377,13 +385,22 @@ public final class BytesToBytesMap {
return this;
}
- Location with(Object page, long keyOffsetInPage) {
+ private Location with(MemoryBlock page, long offsetInPage) {
this.isDefined = true;
- updateAddressesAndSizes(page, keyOffsetInPage);
+ this.memoryPage = page;
+ updateAddressesAndSizes(page.getBaseObject(), offsetInPage);
return this;
}
/**
+ * Returns the memory page that contains the current record.
+ * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
+ */
+ public MemoryBlock getMemoryPage() {
+ return this.memoryPage;
+ }
+
+ /**
* Returns true if the key is defined at this position, and false otherwise.
*/
public boolean isDefined() {
@@ -538,7 +555,7 @@ public final class BytesToBytesMap {
long insertCursor = dataPageInsertOffset;
// Compute all of our offsets up-front:
- final long totalLengthOffset = insertCursor;
+ final long recordOffset = insertCursor;
insertCursor += 4;
final long keyLengthOffset = insertCursor;
insertCursor += 4;
@@ -547,7 +564,7 @@ public final class BytesToBytesMap {
final long valueDataOffsetInPage = insertCursor;
insertCursor += valueLengthBytes; // word used to store the value size
- PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset,
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset,
keyLengthBytes + valueLengthBytes);
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
// Copy the key
@@ -569,7 +586,7 @@ public final class BytesToBytesMap {
numElements++;
bitset.set(pos);
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
- dataPage, totalLengthOffset);
+ dataPage, recordOffset);
longArray.set(pos * 2, storedKeyAddress);
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
@@ -618,6 +635,10 @@ public final class BytesToBytesMap {
assert(dataPages.isEmpty());
}
+ public TaskMemoryManager getTaskMemoryManager() {
+ return taskMemoryManager;
+ }
+
/** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
public long getTotalMemoryConsumption() {
long totalDataPagesSize = 0L;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index d79325aea8..000be70f17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -125,6 +125,8 @@ object UnsafeProjection {
GenerateUnsafeProjection.generate(exprs)
}
+ def create(expr: Expression): UnsafeProjection = create(Seq(expr))
+
/**
* Returns an UnsafeProjection for given sequence of Expressions, which will be bound to
* `inputSchema`.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index dbd4616d28..cc848aa199 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -21,6 +21,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Private
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.StructType
/**
* Inherits some default implementation for Java from `Ordering[Row]`
@@ -43,7 +44,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
in.map(BindReferences.bindReference(_, inputSchema))
- protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = {
+ /**
+ * Creates a code gen ordering for sorting this schema, in ascending order.
+ */
+ def create(schema: StructType): BaseOrdering = {
+ create(schema.zipWithIndex.map { case (field, ordinal) =>
+ SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
+ })
+ }
+
+ protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
val ctx = newCodeGenContext()
val comparisons = ordering.map { order =>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index c18b6dea6b..a0a8dd5154 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -17,19 +17,26 @@
package org.apache.spark.sql.execution;
+import java.io.IOException;
+
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
/**
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -225,4 +232,93 @@ public final class UnsafeFixedWidthAggregationMap {
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
}
+ /**
+ * Sorts the key, value data in this map in place, and return them as an iterator.
+ *
+ * The only memory that is allocated is the address/prefix array, 16 bytes per record.
+ */
+ public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() {
+ int numElements = map.numElements();
+ final int numKeyFields = groupingKeySchema.size();
+ TaskMemoryManager memoryManager = map.getTaskMemoryManager();
+
+ UnsafeExternalRowSorter.PrefixComputer prefixComp =
+ SortPrefixUtils.createPrefixGenerator(groupingKeySchema);
+ PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema);
+
+ final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema);
+ RecordComparator recordComparator = new RecordComparator() {
+ private final UnsafeRow row1 = new UnsafeRow();
+ private final UnsafeRow row2 = new UnsafeRow();
+
+ @Override
+ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+ row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
+ row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+ return ordering.compare(row1, row2);
+ }
+ };
+
+ // Insert the records into the in-memory sorter.
+ final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
+ memoryManager, recordComparator, prefixComparator, numElements);
+
+ BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+ UnsafeRow row = new UnsafeRow();
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ final Object baseObject = loc.getKeyAddress().getBaseObject();
+ final long baseOffset = loc.getKeyAddress().getBaseOffset();
+
+ // Get encoded memory address
+ MemoryBlock page = loc.getMemoryPage();
+ long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+
+ // Compute prefix
+ row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+ final long prefix = prefixComp.computePrefix(row);
+
+ sorter.insertRecord(address, prefix);
+ }
+
+ // Return the sorted result as an iterator.
+ return new KVIterator<UnsafeRow, UnsafeRow>() {
+
+ private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
+ private final UnsafeRow key = new UnsafeRow();
+ private final UnsafeRow value = new UnsafeRow();
+ private int numValueFields = aggregationBufferSchema.size();
+
+ @Override
+ public boolean next() throws IOException {
+ if (sortedIterator.hasNext()) {
+ sortedIterator.loadNext();
+ Object baseObj = sortedIterator.getBaseObject();
+ long recordOffset = sortedIterator.getBaseOffset();
+ int recordLen = sortedIterator.getRecordLength();
+ int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
+ key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+ value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen);
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public UnsafeRow getKey() {
+ return key;
+ }
+
+ @Override
+ public UnsafeRow getValue() {
+ return value;
+ }
+
+ @Override
+ public void close() {
+ // Do nothing
+ }
+ };
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index a2145b185c..17d4166af5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
@@ -46,4 +47,19 @@ object SortPrefixUtils {
case _ => NoOpPrefixComparator
}
}
+
+ def getPrefixComparator(schema: StructType): PrefixComparator = {
+ val field = schema.head
+ getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ }
+
+ def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
+ val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
+ val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending)))
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 6a2c51ca88..098bdd0017 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -140,4 +140,38 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}
+ test("test sorting") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 128, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+
+ val rand = new Random(42)
+ val groupKeys: Set[String] = Seq.fill(512) {
+ Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+ }.toSet
+ groupKeys.foreach { keyString =>
+ val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
+ buf.setInt(0, keyString.length)
+ assert(buf != null)
+ }
+
+ val out = new scala.collection.mutable.ArrayBuffer[String]
+ val iter = map.sortedIterator()
+ while (iter.next()) {
+ assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
+ out += iter.getKey.getString(0)
+ }
+
+ assert(out === groupKeys.toSeq.sorted)
+
+ map.free()
+ }
+
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
index fb163401c0..5c9d5d9a38 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
@@ -17,9 +17,11 @@
package org.apache.spark.unsafe;
+import java.io.IOException;
+
public abstract class KVIterator<K, V> {
- public abstract boolean next();
+ public abstract boolean next() throws IOException;
public abstract K getKey();