diff options
Diffstat (limited to 'sql')
4 files changed, 97 insertions, 60 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java new file mode 100644 index 0000000000..59c774da74 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.unsafe.KVIterator; + +public abstract class UnsafeKeyValueSorter { + + public abstract void insert(UnsafeRow key, UnsafeRow value); + + public abstract KVIterator<UnsafeRow, UnsafeRow> sort() throws IOException; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 08a98cdd94..c18b6dea6b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution; -import java.io.IOException; -import java.util.Iterator; - import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; @@ -28,6 +25,7 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -157,53 +155,54 @@ public final class UnsafeFixedWidthAggregationMap { } /** - * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}. - */ - public static class MapEntry { - private MapEntry() { }; - public final UnsafeRow key = new UnsafeRow(); - public final UnsafeRow value = new UnsafeRow(); - } - - /** * Returns an iterator over the keys and values in this map. * * For efficiency, each call returns the same object. */ - public Iterator<MapEntry> iterator() { - return new Iterator<MapEntry>() { + public KVIterator<UnsafeRow, UnsafeRow> iterator() { + return new KVIterator<UnsafeRow, UnsafeRow>() { + + private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator(); + private final UnsafeRow key = new UnsafeRow(); + private final UnsafeRow value = new UnsafeRow(); - private final MapEntry entry = new MapEntry(); - private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator(); + @Override + public boolean next() { + if (mapLocationIterator.hasNext()) { + final BytesToBytesMap.Location loc = mapLocationIterator.next(); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + loc.getKeyLength() + ); + value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + loc.getValueLength() + ); + return true; + } else { + return false; + } + } @Override - public boolean hasNext() { - return mapLocationIterator.hasNext(); + public UnsafeRow getKey() { + return key; } @Override - public MapEntry next() { - final BytesToBytesMap.Location loc = mapLocationIterator.next(); - final MemoryLocation keyAddress = loc.getKeyAddress(); - final MemoryLocation valueAddress = loc.getValueAddress(); - entry.key.pointTo( - keyAddress.getBaseObject(), - keyAddress.getBaseOffset(), - groupingKeySchema.length(), - loc.getKeyLength() - ); - entry.value.pointTo( - valueAddress.getBaseObject(), - valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), - loc.getValueLength() - ); - return entry; + public UnsafeRow getValue() { + return value; } @Override - public void remove() { - throw new UnsupportedOperationException(); + public void close() { + // Do nothing. } }; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 469de6ca8e..cd87b8deba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -287,21 +287,26 @@ case class GeneratedAggregate( new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() + private[this] var _hasNext = mapIterator.next() - def hasNext: Boolean = mapIterator.hasNext + def hasNext: Boolean = _hasNext def next(): InternalRow = { - val entry = mapIterator.next() - val result = resultProjection(joinedRow(entry.key, entry.value)) - if (hasNext) { - result + if (_hasNext) { + val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue)) + _hasNext = mapIterator.next() + if (_hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.free() + resultCopy + } } else { - // This is the last element in the iterator, so let's free the buffer. Before we do, - // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory - val resultCopy = result.copy() - aggregationMap.free() - resultCopy + throw new java.util.NoSuchElementException } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 79fd52dacd..6a2c51ca88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.scalatest.{BeforeAndAfterEach, Matchers} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random import org.apache.spark.SparkFunSuite @@ -52,7 +53,7 @@ class UnsafeFixedWidthAggregationMapSuite override def afterEach(): Unit = { if (taskMemoryManager != null) { - val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask + val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) assert(leakedShuffleMemory === 0) taskMemoryManager = null @@ -80,7 +81,7 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES, false // disable perf metrics ) - assert(!map.iterator().hasNext) + assert(!map.iterator().next()) map.free() } @@ -100,13 +101,13 @@ class UnsafeFixedWidthAggregationMapSuite // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) assert(map.getAggregationBuffer(groupKey) != null) val iter = map.iterator() - val entry = iter.next() - assert(!iter.hasNext) - entry.key.getString(0) should be ("cats") - entry.value.getInt(0) should be (0) + assert(iter.next()) + iter.getKey.getString(0) should be ("cats") + iter.getValue.getInt(0) should be (0) + assert(!iter.next()) // Modifications to rows retrieved from the map should update the values in the map - entry.value.setInt(0, 42) + iter.getValue.setInt(0, 42) map.getAggregationBuffer(groupKey).getInt(0) should be (42) map.free() @@ -128,12 +129,14 @@ class UnsafeFixedWidthAggregationMapSuite groupKeys.foreach { keyString => assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null) } - val seenKeys: Set[String] = map.iterator().asScala.map { entry => - entry.key.getString(0) - }.toSet - seenKeys.size should be (groupKeys.size) - seenKeys should be (groupKeys) + val seenKeys = new mutable.HashSet[String] + val iter = map.iterator() + while (iter.next()) { + seenKeys += iter.getKey.getString(0) + } + assert(seenKeys.size === groupKeys.size) + assert(seenKeys === groupKeys) map.free() } |