aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java2
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java6
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java112
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala30
4 files changed, 65 insertions, 85 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index bf5f965a9d..dec7fcfa0d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -428,7 +428,7 @@ public final class UnsafeExternalSorter {
public UnsafeSorterIterator getSortedIterator() throws IOException {
assert(inMemSorter != null);
- final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
+ final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
return inMemoryIterator;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 3131465391..1e4b8a116e 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -133,7 +133,7 @@ public final class UnsafeInMemorySorter {
pointerArrayInsertPosition++;
}
- private static final class SortedIterator extends UnsafeSorterIterator {
+ public static final class SortedIterator extends UnsafeSorterIterator {
private final TaskMemoryManager memoryManager;
private final int sortBufferInsertPosition;
@@ -144,7 +144,7 @@ public final class UnsafeInMemorySorter {
private long keyPrefix;
private int recordLength;
- SortedIterator(
+ private SortedIterator(
TaskMemoryManager memoryManager,
int sortBufferInsertPosition,
long[] sortBuffer) {
@@ -186,7 +186,7 @@ public final class UnsafeInMemorySorter {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
- public UnsafeSorterIterator getSortedIterator() {
+ public SortedIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index f6b0176863..312ec8ea0d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -134,7 +134,7 @@ public final class UnsafeKVExternalSorter {
value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
}
- public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
+ public KVSorterIterator sortedIterator() throws IOException {
try {
final UnsafeSorterIterator underlying = sorter.getSortedIterator();
if (!underlying.hasNext()) {
@@ -142,58 +142,7 @@ public final class UnsafeKVExternalSorter {
// here in order to prevent memory leaks.
cleanupResources();
}
-
- return new KVIterator<UnsafeRow, UnsafeRow>() {
- private UnsafeRow key = new UnsafeRow();
- private UnsafeRow value = new UnsafeRow();
- private int numKeyFields = keySchema.size();
- private int numValueFields = valueSchema.size();
-
- @Override
- public boolean next() throws IOException {
- try {
- if (underlying.hasNext()) {
- underlying.loadNext();
-
- Object baseObj = underlying.getBaseObject();
- long recordOffset = underlying.getBaseOffset();
- int recordLen = underlying.getRecordLength();
-
- // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
- int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
- int valueLen = recordLen - keyLen - 4;
-
- key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
- value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
-
- return true;
- } else {
- key = null;
- value = null;
- cleanupResources();
- return false;
- }
- } catch (IOException e) {
- cleanupResources();
- throw e;
- }
- }
-
- @Override
- public UnsafeRow getKey() {
- return key;
- }
-
- @Override
- public UnsafeRow getValue() {
- return value;
- }
-
- @Override
- public void close() {
- cleanupResources();
- }
- };
+ return new KVSorterIterator(underlying);
} catch (IOException e) {
cleanupResources();
throw e;
@@ -233,4 +182,61 @@ public final class UnsafeKVExternalSorter {
return ordering.compare(row1, row2);
}
}
+
+ public class KVSorterIterator extends KVIterator<UnsafeRow, UnsafeRow> {
+ private UnsafeRow key = new UnsafeRow();
+ private UnsafeRow value = new UnsafeRow();
+ private final int numKeyFields = keySchema.size();
+ private final int numValueFields = valueSchema.size();
+ private final UnsafeSorterIterator underlying;
+
+ private KVSorterIterator(UnsafeSorterIterator underlying) {
+ this.underlying = underlying;
+ }
+
+ @Override
+ public boolean next() throws IOException {
+ try {
+ if (underlying.hasNext()) {
+ underlying.loadNext();
+
+ Object baseObj = underlying.getBaseObject();
+ long recordOffset = underlying.getBaseOffset();
+ int recordLen = underlying.getRecordLength();
+
+ // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+ int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
+ int valueLen = recordLen - keyLen - 4;
+
+ key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+ value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+
+ return true;
+ } else {
+ key = null;
+ value = null;
+ cleanupResources();
+ return false;
+ }
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+ @Override
+ public UnsafeRow getKey() {
+ return key;
+ }
+
+ @Override
+ public UnsafeRow getValue() {
+ return value;
+ }
+
+ @Override
+ public void close() {
+ cleanupResources();
+ }
+ };
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
index 37d34eb7cc..b465787fe8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
@@ -17,12 +17,12 @@
package org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.types.StructType
/**
@@ -230,7 +230,7 @@ class UnsafeHybridAggregationIterator(
}
// Step 5: Get the sorted iterator from the externalSorter.
- val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator()
+ val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()
// Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
// For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
@@ -368,31 +368,5 @@ object UnsafeHybridAggregationIterator {
newMutableProjection,
outputsUnsafeRows)
}
-
- def createFromKVIterator(
- groupingKeyAttributes: Seq[Attribute],
- valueAttributes: Seq[Attribute],
- inputKVIterator: KVIterator[UnsafeRow, InternalRow],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
- nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- initialInputBufferOffset: Int,
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
- new UnsafeHybridAggregationIterator(
- groupingKeyAttributes,
- valueAttributes,
- inputKVIterator,
- nonCompleteAggregateExpressions,
- nonCompleteAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- newMutableProjection,
- outputsUnsafeRows)
- }
// scalastyle:on
}