diff options
Diffstat (limited to 'sql/core')
4 files changed, 401 insertions, 6 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java new file mode 100644 index 0000000000..66012e3c94 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.Iterator; + +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. + * + * This map supports a maximum of 2 billion keys. + */ +public final class UnsafeFixedWidthAggregationMap { + + /** + * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the + * map, we copy this buffer and use it as the value. + */ + private final byte[] emptyAggregationBuffer; + + private final StructType aggregationBufferSchema; + + private final StructType groupingKeySchema; + + /** + * Encodes grouping keys as UnsafeRows. + */ + private final UnsafeProjection groupingKeyProjection; + + /** + * A hashmap which maps from opaque bytearray keys to bytearray values. + */ + private final BytesToBytesMap map; + + /** + * Re-used pointer to the current aggregation buffer + */ + private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + + private final boolean enablePerfMetrics; + + /** + * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given + * schema, false otherwise. + */ + public static boolean supportsAggregationBufferSchema(StructType schema) { + for (StructField field: schema.fields()) { + if (field.dataType() instanceof DecimalType) { + DecimalType dt = (DecimalType) field.dataType(); + if (dt.precision() > Decimal.MAX_LONG_DIGITS()) { + return false; + } + } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * Create a new UnsafeFixedWidthAggregationMap. + * + * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) + * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. + * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with + * other tasks. + * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. + * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) + */ + public UnsafeFixedWidthAggregationMap( + InternalRow emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + int initialCapacity, + long pageSizeBytes, + boolean enablePerfMetrics) { + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); + this.groupingKeySchema = groupingKeySchema; + this.map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); + this.enablePerfMetrics = enablePerfMetrics; + + // Initialize the buffer for aggregation value + final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); + this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + + UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); + } + + /** + * Return the aggregation buffer for the current group. For efficiency, all calls to this method + * return the same object. If additional memory could not be allocated, then this method will + * signal an error by returning null. + */ + public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { + final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); + + // Probe our map using the serialized key + final BytesToBytesMap.Location loc = map.lookup( + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes()); + if (!loc.isDefined()) { + // This is the first time that we've seen this grouping key, so we'll insert a copy of the + // empty aggregation buffer into the map: + boolean putSucceeded = loc.putNewKey( + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes(), + emptyAggregationBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + emptyAggregationBuffer.length + ); + if (!putSucceeded) { + return null; + } + } + + // Reset the pointer to point to the value that we just stored or looked up: + final MemoryLocation address = loc.getValueAddress(); + currentAggregationBuffer.pointTo( + address.getBaseObject(), + address.getBaseOffset(), + aggregationBufferSchema.length(), + loc.getValueLength() + ); + return currentAggregationBuffer; + } + + /** + * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}. + */ + public static class MapEntry { + private MapEntry() { }; + public final UnsafeRow key = new UnsafeRow(); + public final UnsafeRow value = new UnsafeRow(); + } + + /** + * Returns an iterator over the keys and values in this map. + * + * For efficiency, each call returns the same object. + */ + public Iterator<MapEntry> iterator() { + return new Iterator<MapEntry>() { + + private final MapEntry entry = new MapEntry(); + private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator(); + + @Override + public boolean hasNext() { + return mapLocationIterator.hasNext(); + } + + @Override + public MapEntry next() { + final BytesToBytesMap.Location loc = mapLocationIterator.next(); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + entry.key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + loc.getKeyLength() + ); + entry.value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + loc.getValueLength() + ); + return entry; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Free the unsafe memory associated with this map. + */ + public void free() { + map.free(); + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + public void printPerfMetrics() { + if (!enablePerfMetrics) { + throw new IllegalStateException("Perf metrics not enabled"); + } + System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); + System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); + System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); + System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index d851eae3fc..469de6ca8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.io.IOException + import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD @@ -266,6 +268,7 @@ case class GeneratedAggregate( aggregationBufferSchema, groupKeySchema, TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, 1024 * 16, // initial capacity pageSizeBytes, false // disable tracking of performance metrics @@ -275,6 +278,9 @@ case class GeneratedAggregate( val currentRow: InternalRow = iter.next() val groupKey: InternalRow = groupProjection(currentRow) val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) + if (aggregationBuffer == null) { + throw new IOException("Could not allocate memory to grow aggregation buffer") + } updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index f88a45f48a..cc8bbfd2f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.joins -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} +import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer @@ -217,7 +219,7 @@ private[joins] final class UnsafeHashedRelation( } } - override def writeExternal(out: ObjectOutput): Unit = { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeInt(hashTable.size()) val iter = hashTable.entrySet().iterator() @@ -256,16 +258,26 @@ private[joins] final class UnsafeHashedRelation( } } - override def readExternal(in: ObjectInput): Unit = { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { val nKeys = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory - val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + + // Dummy shuffle memory manager which always grants all memory allocation requests. + // We use this because it doesn't make sense count shared broadcast variables' memory usage + // towards individual tasks' quotas. In the future, we should devise a better way of handling + // this. + val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) { + override def tryToAcquire(numBytes: Long): Long = numBytes + override def release(numBytes: Long): Unit = {} + } val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) .getSizeAsBytes("spark.buffer.pageSize", "64m") binaryMap = new BytesToBytesMap( - memoryManager, + taskMemoryManager, + shuffleMemoryManager, nKeys * 2, // reduce hash collision pageSizeBytes) @@ -287,8 +299,11 @@ private[joins] final class UnsafeHashedRelation( // put it into binary map val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) assert(!loc.isDefined, "Duplicated key found!") - loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, + val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + if (!putSuceeded) { + throw new IOException("Could not allocate memory to grow BytesToBytesMap") + } i += 1 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala new file mode 100644 index 0000000000..79fd52dacd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.types.UTF8String + + +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach { + + import UnsafeFixedWidthAggregationMap._ + + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) + private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyAggregationBuffer: InternalRow = InternalRow(0) + private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes + + private var taskMemoryManager: TaskMemoryManager = null + private var shuffleMemoryManager: ShuffleMemoryManager = null + + override def beforeEach(): Unit = { + taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue) + } + + override def afterEach(): Unit = { + if (taskMemoryManager != null) { + val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask + assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) + assert(leakedShuffleMemory === 0) + taskMemoryManager = null + } + } + + test("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + + test("empty map") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 1024, // initial capacity, + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + assert(!map.iterator().hasNext) + map.free() + } + + test("updating values for a single key") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 1024, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + val groupKey = InternalRow(UTF8String.fromString("cats")) + + // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) + assert(map.getAggregationBuffer(groupKey) != null) + val iter = map.iterator() + val entry = iter.next() + assert(!iter.hasNext) + entry.key.getString(0) should be ("cats") + entry.value.getInt(0) should be (0) + + // Modifications to rows retrieved from the map should update the values in the map + entry.value.setInt(0, 42) + map.getAggregationBuffer(groupKey).getInt(0) should be (42) + + map.free() + } + + test("inserting large random keys") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet + groupKeys.foreach { keyString => + assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null) + } + val seenKeys: Set[String] = map.iterator().asScala.map { entry => + entry.key.getString(0) + }.toSet + seenKeys.size should be (groupKeys.size) + seenKeys should be (groupKeys) + + map.free() + } + +} |