aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-07-31 19:19:27 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-07-31 19:19:27 -0700
commit8cb415a4b9bc1f82127ccce4a5579d433f4e8f83 (patch)
treef9071996c6485e77700463106ae86c1a346508c9 /sql/core
parentf51fd6fbb4d9822502f98b312251e317d757bc3a (diff)
downloadspark-8cb415a4b9bc1f82127ccce4a5579d433f4e8f83.tar.gz
spark-8cb415a4b9bc1f82127ccce4a5579d433f4e8f83.tar.bz2
spark-8cb415a4b9bc1f82127ccce4a5579d433f4e8f83.zip
[SPARK-9451] [SQL] Support entries larger than default page size in BytesToBytesMap & integrate with ShuffleMemoryManager
This patch adds support for entries larger than the default page size in BytesToBytesMap. These large rows are handled by allocating special overflow pages to hold individual entries. In addition, this patch integrates BytesToBytesMap with the ShuffleMemoryManager: - Move BytesToBytesMap from `unsafe` to `core` so that it can import `ShuffleMemoryManager`. - Before allocating new data pages, ask the ShuffleMemoryManager to reserve the memory: - `putNewKey()` now returns a boolean to indicate whether the insert succeeded or failed due to a lack of memory. The caller can use this value to respond to the memory pressure (e.g. by spilling). - `UnsafeFixedWidthAggregationMap. getAggregationBuffer()` now returns `null` to signal failure due to a lack of memory. - Updated all uses of these classes to handle these error conditions. - Added new tests for allocating large records and for allocations which fail due to memory pressure. - Extended the `afterAll()` test teardown methods to detect ShuffleMemoryManager leaks. Author: Josh Rosen <joshrosen@databricks.com> Closes #7762 from JoshRosen/large-rows and squashes the following commits: ae7bc56 [Josh Rosen] Fix compilation 82fc657 [Josh Rosen] Merge remote-tracking branch 'origin/master' into large-rows 34ab943 [Josh Rosen] Remove semi 31a525a [Josh Rosen] Integrate BytesToBytesMap with ShuffleMemoryManager. 626b33c [Josh Rosen] Move code to sql/core and spark/core packages so that ShuffleMemoryManager can be integrated ec4484c [Josh Rosen] Move BytesToBytesMap from unsafe package to core. 642ed69 [Josh Rosen] Rename size to numElements bea1152 [Josh Rosen] Add basic test. 2cd3570 [Josh Rosen] Remove accidental duplicated code 07ff9ef [Josh Rosen] Basic support for large rows in BytesToBytesMap.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java234
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala27
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala140
4 files changed, 401 insertions, 6 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
new file mode 100644
index 0000000000..66012e3c94
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
+ *
+ * This map supports a maximum of 2 billion keys.
+ */
+public final class UnsafeFixedWidthAggregationMap {
+
+ /**
+ * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
+ * map, we copy this buffer and use it as the value.
+ */
+ private final byte[] emptyAggregationBuffer;
+
+ private final StructType aggregationBufferSchema;
+
+ private final StructType groupingKeySchema;
+
+ /**
+ * Encodes grouping keys as UnsafeRows.
+ */
+ private final UnsafeProjection groupingKeyProjection;
+
+ /**
+ * A hashmap which maps from opaque bytearray keys to bytearray values.
+ */
+ private final BytesToBytesMap map;
+
+ /**
+ * Re-used pointer to the current aggregation buffer
+ */
+ private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+
+ private final boolean enablePerfMetrics;
+
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
+ * schema, false otherwise.
+ */
+ public static boolean supportsAggregationBufferSchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (field.dataType() instanceof DecimalType) {
+ DecimalType dt = (DecimalType) field.dataType();
+ if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
+ return false;
+ }
+ } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Create a new UnsafeFixedWidthAggregationMap.
+ *
+ * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
+ * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
+ * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+ * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
+ * other tasks.
+ * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+ * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
+ * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
+ */
+ public UnsafeFixedWidthAggregationMap(
+ InternalRow emptyAggregationBuffer,
+ StructType aggregationBufferSchema,
+ StructType groupingKeySchema,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ int initialCapacity,
+ long pageSizeBytes,
+ boolean enablePerfMetrics) {
+ this.aggregationBufferSchema = aggregationBufferSchema;
+ this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
+ this.groupingKeySchema = groupingKeySchema;
+ this.map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+ this.enablePerfMetrics = enablePerfMetrics;
+
+ // Initialize the buffer for aggregation value
+ final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
+ this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+ assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
+ UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
+ }
+
+ /**
+ * Return the aggregation buffer for the current group. For efficiency, all calls to this method
+ * return the same object. If additional memory could not be allocated, then this method will
+ * signal an error by returning null.
+ */
+ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
+ final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
+
+ // Probe our map using the serialized key
+ final BytesToBytesMap.Location loc = map.lookup(
+ unsafeGroupingKeyRow.getBaseObject(),
+ unsafeGroupingKeyRow.getBaseOffset(),
+ unsafeGroupingKeyRow.getSizeInBytes());
+ if (!loc.isDefined()) {
+ // This is the first time that we've seen this grouping key, so we'll insert a copy of the
+ // empty aggregation buffer into the map:
+ boolean putSucceeded = loc.putNewKey(
+ unsafeGroupingKeyRow.getBaseObject(),
+ unsafeGroupingKeyRow.getBaseOffset(),
+ unsafeGroupingKeyRow.getSizeInBytes(),
+ emptyAggregationBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ emptyAggregationBuffer.length
+ );
+ if (!putSucceeded) {
+ return null;
+ }
+ }
+
+ // Reset the pointer to point to the value that we just stored or looked up:
+ final MemoryLocation address = loc.getValueAddress();
+ currentAggregationBuffer.pointTo(
+ address.getBaseObject(),
+ address.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
+ );
+ return currentAggregationBuffer;
+ }
+
+ /**
+ * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
+ */
+ public static class MapEntry {
+ private MapEntry() { };
+ public final UnsafeRow key = new UnsafeRow();
+ public final UnsafeRow value = new UnsafeRow();
+ }
+
+ /**
+ * Returns an iterator over the keys and values in this map.
+ *
+ * For efficiency, each call returns the same object.
+ */
+ public Iterator<MapEntry> iterator() {
+ return new Iterator<MapEntry>() {
+
+ private final MapEntry entry = new MapEntry();
+ private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
+
+ @Override
+ public boolean hasNext() {
+ return mapLocationIterator.hasNext();
+ }
+
+ @Override
+ public MapEntry next() {
+ final BytesToBytesMap.Location loc = mapLocationIterator.next();
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ entry.key.pointTo(
+ keyAddress.getBaseObject(),
+ keyAddress.getBaseOffset(),
+ groupingKeySchema.length(),
+ loc.getKeyLength()
+ );
+ entry.value.pointTo(
+ valueAddress.getBaseObject(),
+ valueAddress.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
+ );
+ return entry;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ /**
+ * Free the unsafe memory associated with this map.
+ */
+ public void free() {
+ map.free();
+ }
+
+ @SuppressWarnings("UseOfSystemOutOrSystemErr")
+ public void printPerfMetrics() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException("Perf metrics not enabled");
+ }
+ System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
+ System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
+ System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
+ System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+ }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index d851eae3fc..469de6ca8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.io.IOException
+
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
@@ -266,6 +268,7 @@ case class GeneratedAggregate(
aggregationBufferSchema,
groupKeySchema,
TaskContext.get.taskMemoryManager(),
+ SparkEnv.get.shuffleMemoryManager,
1024 * 16, // initial capacity
pageSizeBytes,
false // disable tracking of performance metrics
@@ -275,6 +278,9 @@ case class GeneratedAggregate(
val currentRow: InternalRow = iter.next()
val groupKey: InternalRow = groupProjection(currentRow)
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+ if (aggregationBuffer == null) {
+ throw new IOException("Could not allocate memory to grow aggregation buffer")
+ }
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f88a45f48a..cc8bbfd2f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.joins
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
+import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.CompactBuffer
@@ -217,7 +219,7 @@ private[joins] final class UnsafeHashedRelation(
}
}
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(hashTable.size())
val iter = hashTable.entrySet().iterator()
@@ -256,16 +258,26 @@ private[joins] final class UnsafeHashedRelation(
}
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val nKeys = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
- val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+
+ // Dummy shuffle memory manager which always grants all memory allocation requests.
+ // We use this because it doesn't make sense count shared broadcast variables' memory usage
+ // towards individual tasks' quotas. In the future, we should devise a better way of handling
+ // this.
+ val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) {
+ override def tryToAcquire(numBytes: Long): Long = numBytes
+ override def release(numBytes: Long): Unit = {}
+ }
val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
.getSizeAsBytes("spark.buffer.pageSize", "64m")
binaryMap = new BytesToBytesMap(
- memoryManager,
+ taskMemoryManager,
+ shuffleMemoryManager,
nKeys * 2, // reduce hash collision
pageSizeBytes)
@@ -287,8 +299,11 @@ private[joins] final class UnsafeHashedRelation(
// put it into binary map
val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
assert(!loc.isDefined, "Duplicated key found!")
- loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+ val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+ if (!putSuceeded) {
+ throw new IOException("Could not allocate memory to grow BytesToBytesMap")
+ }
i += 1
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
new file mode 100644
index 0000000000..79fd52dacd
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.scalatest.{BeforeAndAfterEach, Matchers}
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.types.UTF8String
+
+
+class UnsafeFixedWidthAggregationMapSuite
+ extends SparkFunSuite
+ with Matchers
+ with BeforeAndAfterEach {
+
+ import UnsafeFixedWidthAggregationMap._
+
+ private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
+ private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
+ private def emptyAggregationBuffer: InternalRow = InternalRow(0)
+ private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
+
+ private var taskMemoryManager: TaskMemoryManager = null
+ private var shuffleMemoryManager: ShuffleMemoryManager = null
+
+ override def beforeEach(): Unit = {
+ taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
+ }
+
+ override def afterEach(): Unit = {
+ if (taskMemoryManager != null) {
+ val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask
+ assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
+ assert(leakedShuffleMemory === 0)
+ taskMemoryManager = null
+ }
+ }
+
+ test("supported schemas") {
+ assert(supportsAggregationBufferSchema(
+ StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
+ assert(!supportsAggregationBufferSchema(
+ StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
+ assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
+ assert(
+ !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+ }
+
+ test("empty map") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 1024, // initial capacity,
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+ assert(!map.iterator().hasNext)
+ map.free()
+ }
+
+ test("updating values for a single key") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 1024, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+ val groupKey = InternalRow(UTF8String.fromString("cats"))
+
+ // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
+ assert(map.getAggregationBuffer(groupKey) != null)
+ val iter = map.iterator()
+ val entry = iter.next()
+ assert(!iter.hasNext)
+ entry.key.getString(0) should be ("cats")
+ entry.value.getInt(0) should be (0)
+
+ // Modifications to rows retrieved from the map should update the values in the map
+ entry.value.setInt(0, 42)
+ map.getAggregationBuffer(groupKey).getInt(0) should be (42)
+
+ map.free()
+ }
+
+ test("inserting large random keys") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ 128, // initial capacity
+ PAGE_SIZE_BYTES,
+ false // disable perf metrics
+ )
+ val rand = new Random(42)
+ val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
+ groupKeys.foreach { keyString =>
+ assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null)
+ }
+ val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
+ entry.key.getString(0)
+ }.toSet
+ seenKeys.size should be (groupKeys.size)
+ seenKeys should be (groupKeys)
+
+ map.free()
+ }
+
+}