Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java     234
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala                 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala              27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala 140
4 files changed, 401 insertions, 6 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
new file mode 100644
index 0000000000..66012e3c94
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
+ *
+ * This map supports a maximum of 2 billion keys.
+ */
+public final class UnsafeFixedWidthAggregationMap {
+
+ /**
+ * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
+ * map, we copy this buffer and use it as the value.
+ */
+ private final byte[] emptyAggregationBuffer;
+
+ private final StructType aggregationBufferSchema;
+
+ private final StructType groupingKeySchema;
+
+ /**
+ * Encodes grouping keys as UnsafeRows.
+ */
+ private final UnsafeProjection groupingKeyProjection;
+
+ /**
+ * A hashmap which maps from opaque bytearray keys to bytearray values.
+ */
+ private final BytesToBytesMap map;
+
+ /**
+ * Re-used pointer to the current aggregation buffer
+ */
+ private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+
+ private final boolean enablePerfMetrics;
+
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
+ * schema, false otherwise.
+ */
+ public static boolean supportsAggregationBufferSchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (field.dataType() instanceof DecimalType) {
+ DecimalType dt = (DecimalType) field.dataType();
+ if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
+ return false;
+ }
+ } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Create a new UnsafeFixedWidthAggregationMap.
+ *
+ * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
+ * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
+ * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+ * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
+ * other tasks.
+ * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+ * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
+ * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
+ */
+ public UnsafeFixedWidthAggregationMap(
+ InternalRow emptyAggregationBuffer,
+ StructType aggregationBufferSchema,
+ StructType groupingKeySchema,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ int initialCapacity,
+ long pageSizeBytes,
+ boolean enablePerfMetrics) {
+ this.aggregationBufferSchema = aggregationBufferSchema;
+ this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
+ this.groupingKeySchema = groupingKeySchema;
+ this.map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+ this.enablePerfMetrics = enablePerfMetrics;
+
+ // Initialize the buffer for aggregation value
+ final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
+ this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+ assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
+ UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
+ }
+
+ /**
+ * Return the aggregation buffer for the current group. For efficiency, all calls to this method
+ * return the same object. If additional memory could not be allocated, then this method will
+ * signal an error by returning null.
+ */
+ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
+ final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
+
+ // Probe our map using the serialized key
+ final BytesToBytesMap.Location loc = map.lookup(
+ unsafeGroupingKeyRow.getBaseObject(),
+ unsafeGroupingKeyRow.getBaseOffset(),
+ unsafeGroupingKeyRow.getSizeInBytes());
+ if (!loc.isDefined()) {
+ // This is the first time that we've seen this grouping key, so we'll insert a copy of the
+ // empty aggregation buffer into the map:
+ boolean putSucceeded = loc.putNewKey(
+ unsafeGroupingKeyRow.getBaseObject(),
+ unsafeGroupingKeyRow.getBaseOffset(),
+ unsafeGroupingKeyRow.getSizeInBytes(),
+ emptyAggregationBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ emptyAggregationBuffer.length
+ );
+ if (!putSucceeded) {
+ return null;
+ }
+ }
+
+ // Reset the pointer to point to the value that we just stored or looked up:
+ final MemoryLocation address = loc.getValueAddress();
+ currentAggregationBuffer.pointTo(
+ address.getBaseObject(),
+ address.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
+ );
+ return currentAggregationBuffer;
+ }
+
+ /**
+ * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
+ */
+ public static class MapEntry {
+ private MapEntry() {}
+ public final UnsafeRow key = new UnsafeRow();
+ public final UnsafeRow value = new UnsafeRow();
+ }
+
+ /**
+ * Returns an iterator over the keys and values in this map.
+ *
+ * For efficiency, each call to next() reuses the same MapEntry object.
+ */
+ public Iterator<MapEntry> iterator() {
+ return new Iterator<MapEntry>() {
+
+ private final MapEntry entry = new MapEntry();
+ private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
+
+ @Override
+ public boolean hasNext() {
+ return mapLocationIterator.hasNext();
+ }
+
+ @Override
+ public MapEntry next() {
+ final BytesToBytesMap.Location loc = mapLocationIterator.next();
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ entry.key.pointTo(
+ keyAddress.getBaseObject(),
+ keyAddress.getBaseOffset(),
+ groupingKeySchema.length(),
+ loc.getKeyLength()
+ );
+ entry.value.pointTo(
+ valueAddress.getBaseObject(),
+ valueAddress.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
+ );
+ return entry;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ /**
+ * Free the unsafe memory associated with this map.
+ */
+ public void free() {
+ map.free();
+ }
+
+ @SuppressWarnings("UseOfSystemOutOrSystemErr")
+ public void printPerfMetrics() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException("Perf metrics not enabled");
+ }
+ System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
+ System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
+ System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
+ System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+ }
+
+}
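
For reference, a minimal usage sketch of the new map in Scala, mirroring the test suite added below. The heap-backed TaskMemoryManager and the unbounded ShuffleMemoryManager are illustrative test-style wiring, not how GeneratedAggregate configures the map in production:

    import org.apache.spark.shuffle.ShuffleMemoryManager
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
    import org.apache.spark.unsafe.types.UTF8String

    val groupKeySchema = StructType(StructField("word", StringType) :: Nil)
    val aggBufferSchema = StructType(StructField("count", IntegerType) :: Nil)

    // Only fixed-width buffer schemas are supported; check before constructing.
    require(UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggBufferSchema))

    val map = new UnsafeFixedWidthAggregationMap(
      InternalRow(0),          // the "zero" buffer copied in for each new key
      aggBufferSchema,
      groupKeySchema,
      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
      new ShuffleMemoryManager(Long.MaxValue),  // test-only: grants all requests
      1024,                    // initial capacity
      1L << 26,                // 64 MB page size
      false)                   // perf metrics disabled

    // The first lookup of a key inserts a copy of the zero buffer; a null return
    // signals that memory could not be acquired and must be handled by the caller.
    val buffer = map.getAggregationBuffer(InternalRow(UTF8String.fromString("cats")))
    if (buffer == null) throw new java.io.IOException("could not grow aggregation map")
    buffer.setInt(0, buffer.getInt(0) + 1)  // mutations write through to the map

    map.free()

Because getAggregationBuffer and the iterator both reuse a single mutable object across calls, any key or value that must outlive the next call has to be copied out first.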
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index d851eae3fc..469de6ca8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.io.IOException
+
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
@@ -266,6 +268,7 @@ case class GeneratedAggregate(
aggregationBufferSchema,
groupKeySchema,
TaskContext.get.taskMemoryManager(),
+ SparkEnv.get.shuffleMemoryManager,
1024 * 16, // initial capacity
pageSizeBytes,
false // disable tracking of performance metrics
@@ -275,6 +278,9 @@ case class GeneratedAggregate(
val currentRow: InternalRow = iter.next()
val groupKey: InternalRow = groupProjection(currentRow)
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+ if (aggregationBuffer == null) {
+ throw new IOException("Could not allocate memory to grow aggregation buffer")
+ }
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f88a45f48a..cc8bbfd2f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.joins
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
+import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.CompactBuffer
@@ -217,7 +219,7 @@ private[joins] final class UnsafeHashedRelation(
}
}
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(hashTable.size())
val iter = hashTable.entrySet().iterator()
@@ -256,16 +258,26 @@ private[joins] final class UnsafeHashedRelation(
}
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val nKeys = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
- val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+
+ // Dummy shuffle memory manager which always grants all memory allocation requests.
+ // We use this because it doesn't make sense to count shared broadcast variables' memory usage
+ // towards individual tasks' quotas. In the future, we should devise a better way of handling
+ // this.
+ val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) {
+ override def tryToAcquire(numBytes: Long): Long = numBytes
+ override def release(numBytes: Long): Unit = {}
+ }
val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
.getSizeAsBytes("spark.buffer.pageSize", "64m")
binaryMap = new BytesToBytesMap(
- memoryManager,
+ taskMemoryManager,
+ shuffleMemoryManager,
nKeys * 2, // reduce hash collision
pageSizeBytes)
@@ -287,8 +299,11 @@ private[joins] final class UnsafeHashedRelation(
// put it into binary map
val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
assert(!loc.isDefined, "Duplicated key found!")
- loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+ val putSucceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+ if (!putSucceeded) {
+ throw new IOException("Could not allocate memory to grow BytesToBytesMap")
+ }
i += 1
}
}
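
The readExternal/writeExternal methods above run inside Java serialization, which is why failures are surfaced as IOException (via Utils.tryOrIOException and the explicit throw on a failed putNewKey). A hedged sketch of that round trip, using only standard java.io; the roundTrip helper is illustrative, not part of this patch:

    import java.io._

    // Serializing an Externalizable invokes writeExternal; deserializing invokes
    // the no-arg constructor and then readExternal, either of which may throw
    // IOException through the object streams.
    def roundTrip(obj: Externalizable): AnyRef = {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(obj)      // calls obj.writeExternal(...)
      oos.close()
      val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
      ois.readObject()          // calls readExternal(...), rebuilding the map
    }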
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
new file mode 100644
index 0000000000..79fd52dacd
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.types.UTF8String
+
+
+class UnsafeFixedWidthAggregationMapSuite
+ extends SparkFunSuite
+ with Matchers
+ with BeforeAndAfterEach {
+
+ import UnsafeFixedWidthAggregationMap._
+
+ private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
+ private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
+ private def emptyAggregationBuffer: InternalRow = InternalRow(0)
+ private val PAGE_SIZE_BYTES: Long = 1L << 26 // 64 megabytes
+
+ private var taskMemoryManager: TaskMemoryManager = null
+ private var shuffleMemoryManager: ShuffleMemoryManager = null
+
+ override def beforeEach(): Unit = {
+ taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
+ }
+
+ override def afterEach(): Unit = {
+ if (taskMemoryManager != null) {
+ val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask
+ assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
+ assert(leakedShuffleMemory === 0)
+ taskMemoryManager = null
+ }
+ }
+
+ test("supported schemas") {
+ assert(supportsAggregationBufferSchema(
+ StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
+ assert(!supportsAggregationBufferSchema(
+ StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
+ assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
+ assert(
+ !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+ }
+
+ test("empty map") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 1024, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+ assert(!map.iterator().hasNext)
+ map.free()
+ }
+
+ test("updating values for a single key") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 1024, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+ val groupKey = InternalRow(UTF8String.fromString("cats"))
+
+ // Looking up a key stores a zero-entry in the map (like Python's Counter or defaultdict)
+ assert(map.getAggregationBuffer(groupKey) != null)
+ val iter = map.iterator()
+ val entry = iter.next()
+ assert(!iter.hasNext)
+ entry.key.getString(0) should be ("cats")
+ entry.value.getInt(0) should be (0)
+
+ // Modifications to rows retrieved from the map should update the values in the map
+ entry.value.setInt(0, 42)
+ map.getAggregationBuffer(groupKey).getInt(0) should be (42)
+
+ map.free()
+ }
+
+ test("inserting large random keys") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 128, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+ val rand = new Random(42)
+ val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
+ groupKeys.foreach { keyString =>
+ assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null)
+ }
+ val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
+ entry.key.getString(0)
+ }.toSet
+ seenKeys.size should be (groupKeys.size)
+ seenKeys should be (groupKeys)
+
+ map.free()
+ }
+
+}