Diffstat (limited to 'sql/core/src/main')
3 files changed, 141 insertions, 160 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 8882903bbf..1f1b5389aa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -134,7 +134,7 @@ public final class UnsafeFixedWidthAggregationMap {
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
-      boolean putSucceeded = loc.putNewKey(
+      boolean putSucceeded = loc.append(
         key.getBaseObject(),
         key.getBaseOffset(),
         key.getSizeInBytes(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528639..dc4793e85a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,18 +18,18 @@ package org.apache.spark.sql.execution.joins
 
 import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
 import org.apache.spark.util.collection.CompactBuffer
 
 /**
@@ -54,6 +54,11 @@ private[execution] sealed trait HashedRelation {
    */
   def getMemorySize: Long = 1L  // to make the test happy
 
+  /**
+   * Release any used resources.
+   */
+  def close(): Unit = {}
+
   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
   protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
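Note on the API change above: putNewKey required that the key not already exist (the old code asserted "Duplicated key found!"), while the append method used throughout this patch can add further values under an existing key. A minimal sketch of the insert pattern, under the assumption that Location.append has the signature used in these hunks; the helper name appendRow is hypothetical:

    import java.io.IOException
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.unsafe.map.BytesToBytesMap

    // Append one row under `key`; calling this twice with the same key stores two values.
    def appendRow(map: BytesToBytesMap, key: UnsafeRow, row: UnsafeRow): Unit = {
      val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
      // loc.isDefined would tell us the key already has at least one value.
      val ok = loc.append(
        key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
        row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
      if (!ok) {
        map.free()  // release the map's pages before failing, as the patch does
        throw new IOException("Could not allocate memory to grow BytesToBytesMap")
      }
    }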
@@ -132,163 +137,83 @@ private[execution] object HashedRelation {
 }
 
 /**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed by a BytesToBytesMap.
  *
  * It's serialized in the following format:
  *  [number of keys]
- *  [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- *  ...
- *
- * All the values are serialized as following:
- *   [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- *   ...
+ *  [size of key] [size of value] [key bytes] [bytes for value]
  */
-private[joins] final class UnsafeHashedRelation(
-    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
-  extends HashedRelation
-  with KnownSizeEstimation
-  with Externalizable {
-
-  private[joins] def this() = this(null)  // Needed for serialization
+private[joins] class UnsafeHashedRelation(
+    private var numFields: Int,
+    private var binaryMap: BytesToBytesMap)
+  extends HashedRelation with KnownSizeEstimation with Externalizable {
 
-  // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
-  // This is used in broadcast joins and distributed mode only
-  @transient private[this] var binaryMap: BytesToBytesMap = _
+  private[joins] def this() = this(0, null)  // Needed for serialization
 
-  /**
-   * Return the size of the unsafe map on the executors.
-   *
-   * For broadcast joins, this hashed relation is bigger on the driver because it is
-   * represented as a Java hash map there. While serializing the map to the executors,
-   * however, we rehash the contents in a binary map to reduce the memory footprint on
-   * the executors.
-   *
-   * For non-broadcast joins or in local mode, return 0.
-   */
   override def getMemorySize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      0
-    }
+    binaryMap.getTotalMemoryConsumption
   }
 
   override def estimatedSize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      SizeEstimator.estimate(hashTable)
-    }
+    binaryMap.getTotalMemoryConsumption
   }
 
   override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
-
-    if (binaryMap != null) {
-      // Used in Broadcast join
-      val map = binaryMap  // avoid the compiler error
-      val loc = new map.Location  // this could be allocated in stack
-      binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
-        unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
-      if (loc.isDefined) {
-        val buffer = CompactBuffer[UnsafeRow]()
-
-        val base = loc.getValueBase
-        var offset = loc.getValueOffset
-        val last = offset + loc.getValueLength
-        while (offset < last) {
-          val numFields = Platform.getInt(base, offset)
-          val sizeInBytes = Platform.getInt(base, offset + 4)
-          offset += 8
-
-          val row = new UnsafeRow(numFields)
-          row.pointTo(base, offset, sizeInBytes)
-          buffer += row
-          offset += sizeInBytes
-        }
-        buffer
-      } else {
-        null
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      val buffer = CompactBuffer[UnsafeRow]()
+      val row = new UnsafeRow(numFields)
+      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      buffer += row
+      while (loc.nextValue()) {
+        val row = new UnsafeRow(numFields)
+        row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+        buffer += row
       }
-
+      buffer
     } else {
-      // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
-      hashTable.get(unsafeKey)
+      null
     }
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    if (binaryMap != null) {
-      // This could happen when a cached broadcast object need to be dumped into disk to free memory
-      out.writeInt(binaryMap.numElements())
-
-      var buffer = new Array[Byte](64)
-      def write(base: Object, offset: Long, length: Int): Unit = {
-        if (buffer.length < length) {
-          buffer = new Array[Byte](length)
-        }
-        Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
-        out.write(buffer, 0, length)
-      }
+  override def close(): Unit = {
+    binaryMap.free()
+  }
 
-      val iter = binaryMap.iterator()
-      while (iter.hasNext) {
-        val loc = iter.next()
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(loc.getKeyLength)
-        out.writeInt(loc.getValueLength)
-        write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
-        write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+    out.writeInt(numFields)
+    // TODO: move these into BytesToBytesMap
+    out.writeInt(binaryMap.numKeys())
+    out.writeInt(binaryMap.numValues())
+
+    var buffer = new Array[Byte](64)
+    def write(base: Object, offset: Long, length: Int): Unit = {
+      if (buffer.length < length) {
+        buffer = new Array[Byte](length)
       }
+      Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+      out.write(buffer, 0, length)
+    }
 
-    } else {
-      assert(hashTable != null)
-      out.writeInt(hashTable.size())
-
-      val iter = hashTable.entrySet().iterator()
-      while (iter.hasNext) {
-        val entry = iter.next()
-        val key = entry.getKey
-        val values = entry.getValue
-
-        // write all the values as single byte array
-        var totalSize = 0L
-        var i = 0
-        while (i < values.length) {
-          totalSize += values(i).getSizeInBytes + 4 + 4
-          i += 1
-        }
-        assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(key.getSizeInBytes)
-        out.writeInt(totalSize.toInt)
-        out.write(key.getBytes)
-        i = 0
-        while (i < values.length) {
-          // [num of fields] [num of bytes] [row bytes]
-          // write the integer in native order, so they can be read by UNSAFE.getInt()
-          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-            out.writeInt(values(i).numFields())
-            out.writeInt(values(i).getSizeInBytes)
-          } else {
-            out.writeInt(Integer.reverseBytes(values(i).numFields()))
-            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
-          }
-          out.write(values(i).getBytes)
-          i += 1
-        }
-      }
+    val iter = binaryMap.iterator()
+    while (iter.hasNext) {
+      val loc = iter.next()
+      // [key size] [value size] [key bytes] [value bytes]
+      out.writeInt(loc.getKeyLength)
+      out.writeInt(loc.getValueLength)
+      write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+      write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
     }
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+    numFields = in.readInt()
     val nKeys = in.readInt()
+    val nValues = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
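The rewritten get() walks all values stored under a key with Location.nextValue() instead of decoding a packed [numFields][size][bytes] stream. A standalone sketch of that traversal, assuming nextValue() returns false once the key's values are exhausted; ArrayBuffer stands in for Spark's package-private CompactBuffer:

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.unsafe.map.BytesToBytesMap

    // Collect every value stored under the key that `loc` currently points to.
    def valuesAt(loc: BytesToBytesMap#Location, numFields: Int): ArrayBuffer[UnsafeRow] = {
      val rows = ArrayBuffer.empty[UnsafeRow]
      var hasValue = loc.isDefined
      while (hasValue) {
        val row = new UnsafeRow(numFields)  // wraps the map's memory, no copy
        row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
        rows += row
        hasValue = loc.nextValue()  // advance to the next value for the same key
      }
      rows
    }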
@@ -314,7 +239,7 @@ private[joins] final class UnsafeHashedRelation(
     var i = 0
     var keyBuffer = new Array[Byte](1024)
     var valuesBuffer = new Array[Byte](1024)
-    while (i < nKeys) {
+    while (i < nValues) {
       val keySize = in.readInt()
       val valuesSize = in.readInt()
       if (keySize > keyBuffer.length) {
@@ -326,13 +251,11 @@ private[joins] final class UnsafeHashedRelation(
       }
       in.readFully(valuesBuffer, 0, valuesSize)
 
-      // put it into binary map
       val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
-      assert(!loc.isDefined, "Duplicated key found!")
-      val putSuceeded = loc.putNewKey(
-        keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+      val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
         valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
       if (!putSuceeded) {
+        binaryMap.free()
        throw new IOException("Could not allocate memory to grow BytesToBytesMap")
       }
       i += 1
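For reference, writeExternal and readExternal above agree on this layout: a header of three ints (numFields, numKeys, numValues) followed by one [key size][value size][key bytes][value bytes] record per value. A minimal scanner sketch that treats the payload as a plain big-endian int stream (illustrative only; the real Externalizable stream adds its own object-stream framing):

    import java.io.DataInputStream

    // Walk the header and the per-value records without materializing the map.
    def scanSerializedMap(in: DataInputStream): Unit = {
      val numFields = in.readInt()  // fields per UnsafeRow value
      val numKeys = in.readInt()    // distinct keys
      val numValues = in.readInt()  // total values; >= numKeys when a key repeats
      var i = 0
      while (i < numValues) {
        val keySize = in.readInt()
        val valueSize = in.readInt()
        in.readFully(new Array[Byte](keySize + valueSize))  // [key bytes][value bytes]
        i += 1
      }
    }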
@@ -340,6 +263,29 @@ private[joins] final class UnsafeHashedRelation(
   }
 }
 
+/**
+ * A HashedRelation for UnsafeRow with unique keys.
+ */
+private[joins] final class UniqueUnsafeHashedRelation(
+    private var numFields: Int,
+    private var binaryMap: BytesToBytesMap)
+  extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
+  def getValue(key: InternalRow): InternalRow = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      val row = new UnsafeRow(numFields)
+      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      row
+    } else {
+      null
+    }
+  }
+}
+
 private[joins] object UnsafeHashedRelation {
 
   def apply(
@@ -347,29 +293,54 @@ private[joins] object UnsafeHashedRelation {
       keyGenerator: UnsafeProjection,
       sizeEstimate: Int): HashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
-    // TODO: Use BytesToBytesMap for memory efficiency
-    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+    val taskMemoryManager = if (TaskContext.get() != null) {
+      TaskContext.get().taskMemoryManager()
+    } else {
+      new TaskMemoryManager(
+        new StaticMemoryManager(
+          new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+          Long.MaxValue,
+          Long.MaxValue,
+          1),
+        0)
+    }
+    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+    val binaryMap = new BytesToBytesMap(
+      taskMemoryManager,
+      // Only 70% of the slots can be used before growing; extra capacity helps to reduce collisions
+      (sizeEstimate * 1.5 + 1).toInt,
+      pageSizeBytes)
 
     // Create a mapping of buildKeys -> rows
+    var numFields = 0
+    // Whether all the keys are unique or not
+    var allUnique: Boolean = true
     while (input.hasNext) {
-      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
-      val rowKey = keyGenerator(unsafeRow)
-      if (!rowKey.anyNull) {
-        val existingMatchList = hashTable.get(rowKey)
-        val matchList = if (existingMatchList == null) {
-          val newMatchList = new CompactBuffer[UnsafeRow]()
-          hashTable.put(rowKey.copy(), newMatchList)
-          newMatchList
-        } else {
-          existingMatchList
+      val row = input.next().asInstanceOf[UnsafeRow]
+      numFields = row.numFields()
+      val key = keyGenerator(row)
+      if (!key.anyNull) {
+        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+        if (loc.isDefined) {
+          allUnique = false
+        }
+        val success = loc.append(
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+        if (!success) {
+          binaryMap.free()
+          throw new SparkException("There is not enough memory to build the hash map")
         }
-        matchList += unsafeRow
       }
     }
 
-    // TODO: create UniqueUnsafeRelation
-    new UnsafeHashedRelation(hashTable)
+    if (allUnique) {
+      new UniqueUnsafeHashedRelation(numFields, binaryMap)
+    } else {
+      new UnsafeHashedRelation(numFields, binaryMap)
+    }
   }
 }
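The initial capacity above, (sizeEstimate * 1.5 + 1).toInt, is chosen so the estimated number of entries fits under the map's 70% growth threshold without a resize, since 1.5 > 1/0.7 ≈ 1.43. A worked example with an assumed sizeEstimate of 1000, ignoring any internal rounding the map may apply:

    val sizeEstimate = 1000
    val capacity = (sizeEstimate * 1.5 + 1).toInt  // 1501 slots requested
    val usableBeforeGrow = (capacity * 0.7).toInt  // 1050 entries fit before growing
    assert(usableBeforeGrow >= sizeEstimate)       // no rehash for the estimated input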
@@ -523,7 +494,7 @@ private[joins] object LongHashedRelation {
       keyGenerator: Projection,
       sizeEstimate: Int): HashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
+    // TODO: use LongToBytesMap for better memory efficiency
     val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
 
     // Create a mapping of key -> rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5c4f1ef60f..e3a2eaea5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -57,9 +57,19 @@ case class ShuffledHashJoin(
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+    val context = TaskContext.get()
+    if (!canJoinKeyFitWithinLong) {
+      // build BytesToBytesMap
+      val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
+      // This relation is usually used until the end of the task.
+      context.addTaskCompletionListener((t: TaskContext) =>
+        relation.close()
+      )
+      return relation
+    }
+
     // try to acquire some memory for the hash table, it could trigger other operator to free some
     // memory. The memory acquired here will mostly be used until the end of task.
-    val context = TaskContext.get()
     val memoryManager = context.taskMemoryManager()
     var acquired = 0L
     var used = 0L
@@ -69,18 +79,18 @@ case class ShuffledHashJoin(
 
     val copiedIter = iter.map { row =>
       // It's hard to guess what's exactly memory will be used, we have a rough guess here.
-      // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
-      // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
+      // TODO: use LongToBytesMap instead of HashMap for memory efficiency
+      // Each pair in HashMap will have an UnsafeRow, a CompactBuffer, and maybe 10+ pointers
       val needed = 150 + row.getSizeInBytes
       if (needed > acquired - used) {
         val got = memoryManager.acquireExecutionMemory(
           Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+        acquired += got
         if (got < needed) {
           throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
             "hash join, please use sort merge join by setting " +
             "spark.sql.join.preferSortMergeJoin=true")
         }
-        acquired += got
       }
       used += needed
       // HashedRelation requires that the UnsafeRow should be separate objects.
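Since the BytesToBytesMap behind the relation holds task-managed memory, the join frees it with a task-completion listener instead of relying on GC. A minimal sketch of that cleanup pattern; note HashedRelation is package-private, so this only compiles inside org.apache.spark.sql.execution, and bindToTask is a hypothetical helper:

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.execution.joins.HashedRelation

    // Free the relation's pages when the task finishes, even on failure paths.
    def bindToTask(relation: HashedRelation): HashedRelation = {
      val context = TaskContext.get()
      if (context != null) {
        context.addTaskCompletionListener((_: TaskContext) => relation.close())
      }
      relation
    }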