Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 944 |
1 file changed, 539 insertions, 405 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528639..0427db4e3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,277 +18,189 @@ package org.apache.spark.sql.execution.joins
 
 import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
-import java.util.{HashMap => JavaHashMap}
 
-import org.apache.spark.{SparkConf, SparkEnv}
-import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException}
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.LongType
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
 
 /**
  * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
  * object.
  */
-private[execution] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
   /**
-   * Returns matched rows.
-   */
-  def get(key: InternalRow): Seq[InternalRow]
+   * Returns matched rows.
+   *
+   * Returns null if there are no matched rows.
+   */
+  def get(key: InternalRow): Iterator[InternalRow]
 
   /**
-   * Returns matched rows for a key that has only one column with LongType.
-   */
-  def get(key: Long): Seq[InternalRow] = {
+   * Returns matched rows for a key that has only one column with LongType.
+   *
+   * Returns null if there are no matched rows.
+   */
+  def get(key: Long): Iterator[InternalRow] = {
     throw new UnsupportedOperationException
   }
 
   /**
-   * Returns the size of used memory.
-   */
-  def getMemorySize: Long = 1L  // to make the test happy
-
-  // This is a helper method to implement Externalizable, and is used by
-  // GeneralHashedRelation and UniqueKeyHashedRelation
-  protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
-    out.writeInt(serialized.length) // Write the length of serialized bytes first
-    out.write(serialized)
-  }
-
-  // This is a helper method to implement Externalizable, and is used by
-  // GeneralHashedRelation and UniqueKeyHashedRelation
-  protected def readBytes(in: ObjectInput): Array[Byte] = {
-    val serializedSize = in.readInt() // Read the length of serialized bytes first
-    val bytes = new Array[Byte](serializedSize)
-    in.readFully(bytes)
-    bytes
-  }
-}
-
-/**
- * Interface for a hashed relation that have only one row per key.
- *
- * We should call getValue() for better performance.
- */
-private[execution] trait UniqueHashedRelation extends HashedRelation {
-
-  /**
-   * Returns the matched single row.
-   */
+   * Returns the matched single row.
+   */
   def getValue(key: InternalRow): InternalRow
 
   /**
-   * Returns the matched single row with key that have only one column of LongType.
-   */
+   * Returns the matched single row with a key that has only one column of LongType.
+   */
   def getValue(key: Long): InternalRow = {
     throw new UnsupportedOperationException
   }
 
-  override def get(key: InternalRow): Seq[InternalRow] = {
-    val row = getValue(key)
-    if (row != null) {
-      CompactBuffer[InternalRow](row)
-    } else {
-      null
-    }
-  }
+  /**
+   * Returns true iff all the keys are unique.
+   */
+  def keyIsUnique: Boolean
 
-  override def get(key: Long): Seq[InternalRow] = {
-    val row = getValue(key)
-    if (row != null) {
-      CompactBuffer[InternalRow](row)
-    } else {
-      null
-    }
-  }
+  /**
+   * Returns a read-only copy of this, to be used safely in the current thread.
+   */
+  def asReadOnlyCopy(): HashedRelation
+
+  /**
+   * Release any used resources.
+   */
+  def close(): Unit
 }
 
 private[execution] object HashedRelation {
 
   /**
    * Create a HashedRelation from an Iterator of InternalRow.
-   *
-   * Note: The caller should make sure that these InternalRow are different objects.
    */
   def apply(
-      canJoinKeyFitWithinLong: Boolean,
       input: Iterator[InternalRow],
-      keyGenerator: Projection,
-      sizeEstimate: Int = 64): HashedRelation = {
+      key: Seq[Expression],
+      sizeEstimate: Int = 64,
+      taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
+    val mm = Option(taskMemoryManager).getOrElse {
+      new TaskMemoryManager(
+        new StaticMemoryManager(
+          new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+          Long.MaxValue,
+          Long.MaxValue,
+          1),
+        0)
+    }
 
-    if (canJoinKeyFitWithinLong) {
-      LongHashedRelation(input, keyGenerator, sizeEstimate)
+    if (key.length == 1 && key.head.dataType == LongType) {
+      LongHashedRelation(input, key, sizeEstimate, mm)
     } else {
-      UnsafeHashedRelation(
-        input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+      UnsafeHashedRelation(input, key, sizeEstimate, mm)
    }
   }
 }
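For orientation, here is a minimal usage sketch of the new entry point. It is an illustration only, not part of this patch: it assumes spark-catalyst on the classpath and caller code living in the org.apache.spark.sql.execution.joins package (HashedRelation is private[execution]), and the names rows and toUnsafe are made up. With a single LongType key expression, apply() dispatches to LongHashedRelation; any other key shape goes to UnsafeHashedRelation.

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

object HashedRelationUsageSketch {
  def main(args: Array[String]): Unit = {
    // Build UnsafeRows with schema (key: Long, value: Int); the projection
    // re-uses its buffer, so each row is copied before being collected.
    val toUnsafe = UnsafeProjection.create(Array[DataType](LongType, IntegerType))
    val rows = Seq(InternalRow(1L, 10), InternalRow(1L, 11), InternalRow(2L, 20))
      .map(r => toUnsafe(r).copy())

    // A single LongType key expression, so apply() picks LongHashedRelation.
    val key = Seq(BoundReference(0, LongType, nullable = false))
    val relation = HashedRelation(rows.iterator, key)

    val matches = relation.get(1L) // rows with key 1, or null if the key is absent
    if (matches != null) matches.foreach(row => println(row.getInt(1)))
    relation.close()
  }
}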
 
 /**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed by a BytesToBytesMap.
  *
  * It's serialized in the following format:
  *  [number of keys]
- *  [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- *  ...
- *
- * All the values are serialized as following:
- *   [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- *   ...
+ *  [size of key] [size of value] [key bytes] [bytes for value]
  */
-private[joins] final class UnsafeHashedRelation(
-    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
-  extends HashedRelation
-  with KnownSizeEstimation
-  with Externalizable {
+private[joins] class UnsafeHashedRelation(
+    private var numFields: Int,
+    private var binaryMap: BytesToBytesMap)
+  extends HashedRelation with Externalizable {
 
-  private[joins] def this() = this(null)  // Needed for serialization
+  private[joins] def this() = this(0, null)  // Needed for serialization
 
-  // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
-  // This is used in broadcast joins and distributed mode only
-  @transient private[this] var binaryMap: BytesToBytesMap = _
+  override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
 
-  /**
-   * Return the size of the unsafe map on the executors.
-   *
-   * For broadcast joins, this hashed relation is bigger on the driver because it is
-   * represented as a Java hash map there. While serializing the map to the executors,
-   * however, we rehash the contents in a binary map to reduce the memory footprint on
-   * the executors.
-   *
-   * For non-broadcast joins or in local mode, return 0.
-   */
-  override def getMemorySize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      0
-    }
+  override def asReadOnlyCopy(): UnsafeHashedRelation = {
+    new UnsafeHashedRelation(numFields, binaryMap)
   }
 
-  override def estimatedSize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      SizeEstimator.estimate(hashTable)
-    }
-  }
+  override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption
 
-  override def get(key: InternalRow): Seq[InternalRow] = {
-    val unsafeKey = key.asInstanceOf[UnsafeRow]
+  // re-used in get()/getValue()
+  var resultRow = new UnsafeRow(numFields)
 
-    if (binaryMap != null) {
-      // Used in Broadcast join
-      val map = binaryMap  // avoid the compiler error
-      val loc = new map.Location  // this could be allocated in stack
-      binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
-        unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
-      if (loc.isDefined) {
-        val buffer = CompactBuffer[UnsafeRow]()
-
-        val base = loc.getValueBase
-        var offset = loc.getValueOffset
-        val last = offset + loc.getValueLength
-        while (offset < last) {
-          val numFields = Platform.getInt(base, offset)
-          val sizeInBytes = Platform.getInt(base, offset + 4)
-          offset += 8
-
-          val row = new UnsafeRow(numFields)
-          row.pointTo(base, offset, sizeInBytes)
-          buffer += row
-          offset += sizeInBytes
+  override def get(key: InternalRow): Iterator[InternalRow] = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      new Iterator[UnsafeRow] {
+        private var _hasNext = true
+        override def hasNext: Boolean = _hasNext
+        override def next(): UnsafeRow = {
+          resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+          _hasNext = loc.nextValue()
+          resultRow
         }
-        buffer
-      } else {
-        null
       }
+    } else {
+      null
+    }
+  }
 
+  def getValue(key: InternalRow): InternalRow = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      resultRow
     } else {
-      // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
-      hashTable.get(unsafeKey)
+      null
     }
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    if (binaryMap != null) {
-      // This could happen when a cached broadcast object need to be dumped into disk to free memory
-      out.writeInt(binaryMap.numElements())
-
-      var buffer = new Array[Byte](64)
-      def write(base: Object, offset: Long, length: Int): Unit = {
-        if (buffer.length < length) {
-          buffer = new Array[Byte](length)
-        }
-        Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
-        out.write(buffer, 0, length)
-      }
+  override def close(): Unit = {
+    binaryMap.free()
+  }
 
-      val iter = binaryMap.iterator()
-      while (iter.hasNext) {
-        val loc = iter.next()
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(loc.getKeyLength)
-        out.writeInt(loc.getValueLength)
-        write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
-        write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+    out.writeInt(numFields)
+    // TODO: move these into BytesToBytesMap
+    out.writeInt(binaryMap.numKeys())
+    out.writeInt(binaryMap.numValues())
+
+    var buffer = new Array[Byte](64)
+    def write(base: Object, offset: Long, length: Int): Unit = {
+      if (buffer.length < length) {
+        buffer = new Array[Byte](length)
       }
+      Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+      out.write(buffer, 0, length)
+    }
 
-    } else {
-      assert(hashTable != null)
-      out.writeInt(hashTable.size())
-
-      val iter = hashTable.entrySet().iterator()
-      while (iter.hasNext) {
-        val entry = iter.next()
-        val key = entry.getKey
-        val values = entry.getValue
-
-        // write all the values as single byte array
-        var totalSize = 0L
-        var i = 0
-        while (i < values.length) {
-          totalSize += values(i).getSizeInBytes + 4 + 4
-          i += 1
-        }
-        assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(key.getSizeInBytes)
-        out.writeInt(totalSize.toInt)
-        out.write(key.getBytes)
-        i = 0
-        while (i < values.length) {
-          // [num of fields] [num of bytes] [row bytes]
-          // write the integer in native order, so they can be read by UNSAFE.getInt()
-          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-            out.writeInt(values(i).numFields())
-            out.writeInt(values(i).getSizeInBytes)
-          } else {
-            out.writeInt(Integer.reverseBytes(values(i).numFields()))
-            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
-          }
-          out.write(values(i).getBytes)
-          i += 1
-        }
-      }
+    val iter = binaryMap.iterator()
+    while (iter.hasNext) {
+      val loc = iter.next()
+      // [key size] [value size] [key bytes] [value bytes]
+      out.writeInt(loc.getKeyLength)
+      out.writeInt(loc.getValueLength)
+      write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+      write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
     }
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+    numFields = in.readInt()
+    resultRow = new UnsafeRow(numFields)
     val nKeys = in.readInt()
+    val nValues = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
    // TODO(josh): This needs to be revisited before we merge this patch; making this change now
    // so that tests compile:
@@ -314,7 +226,7 @@ private[joins] final class UnsafeHashedRelation(
     var i = 0
     var keyBuffer = new Array[Byte](1024)
     var valuesBuffer = new Array[Byte](1024)
-    while (i < nKeys) {
+    while (i < nValues) {
       val keySize = in.readInt()
      val valuesSize = in.readInt()
      if (keySize > keyBuffer.length) {
@@ -326,13 +238,11 @@ private[joins] final class UnsafeHashedRelation(
       }
       in.readFully(valuesBuffer, 0, valuesSize)
 
-      // put it into binary map
       val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
-      assert(!loc.isDefined, "Duplicated key found!")
-      val putSuceeded = loc.putNewKey(
-        keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+      val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
         valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
       if (!putSuceeded) {
+        binaryMap.free()
         throw new IOException("Could not allocate memory to grow BytesToBytesMap")
       }
       i += 1
@@ -344,279 +254,503 @@ private[joins] object UnsafeHashedRelation {
 
   def apply(
       input: Iterator[InternalRow],
-      keyGenerator: UnsafeProjection,
-      sizeEstimate: Int): HashedRelation = {
+      key: Seq[Expression],
+      sizeEstimate: Int,
+      taskMemoryManager: TaskMemoryManager): HashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
-    // TODO: Use BytesToBytesMap for memory efficiency
-    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+    val binaryMap = new BytesToBytesMap(
+      taskMemoryManager,
+      // Only 70% of the slots can be used before growing; more capacity helps to reduce collisions
+      (sizeEstimate * 1.5 + 1).toInt,
+      pageSizeBytes)
 
     // Create a mapping of buildKeys -> rows
+    val keyGenerator = UnsafeProjection.create(key)
+    var numFields = 0
     while (input.hasNext) {
-      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
-      val rowKey = keyGenerator(unsafeRow)
-      if (!rowKey.anyNull) {
-        val existingMatchList = hashTable.get(rowKey)
-        val matchList = if (existingMatchList == null) {
-          val newMatchList = new CompactBuffer[UnsafeRow]()
-          hashTable.put(rowKey.copy(), newMatchList)
-          newMatchList
-        } else {
-          existingMatchList
+      val row = input.next().asInstanceOf[UnsafeRow]
+      numFields = row.numFields()
+      val key = keyGenerator(row)
+      if (!key.anyNull) {
+        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+        val success = loc.append(
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+        if (!success) {
+          binaryMap.free()
+          throw new SparkException("Not enough memory to build the hash map")
         }
-        matchList += unsafeRow
       }
     }
 
-    // TODO: create UniqueUnsafeRelation
-    new UnsafeHashedRelation(hashTable)
+    new UnsafeHashedRelation(numFields, binaryMap)
   }
 }
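The new writeExternal()/readExternal() pair above reduces the on-wire format to a flat framing: numFields, numKeys and numValues up front, then one [key size][value size][key bytes][value bytes] record per entry. The following self-contained sketch (illustration only, not part of the patch; the entry data is made up) shows the same framing with plain java.io streams:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object FramingSketch {
  def writeEntry(out: DataOutputStream, key: Array[Byte], value: Array[Byte]): Unit = {
    out.writeInt(key.length)    // [key size]
    out.writeInt(value.length)  // [value size]
    out.write(key)              // [key bytes]
    out.write(value)            // [value bytes]
  }

  def readEntry(in: DataInputStream): (Array[Byte], Array[Byte]) = {
    val key = new Array[Byte](in.readInt())   // reads [key size]
    val value = new Array[Byte](in.readInt()) // reads [value size]
    in.readFully(key)
    in.readFully(value)
    (key, value)
  }

  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    writeEntry(new DataOutputStream(bytes), "key1".getBytes("UTF-8"), "row-payload".getBytes("UTF-8"))
    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val (k, v) = readEntry(in)
    println(new String(k, "UTF-8") + " -> " + new String(v, "UTF-8")) // round-trips
  }
}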
 
 /**
- * An interface for a hashed relation that the key is a Long.
- */
-private[joins] trait LongHashedRelation extends HashedRelation {
-  override def get(key: InternalRow): Seq[InternalRow] = {
-    get(key.getLong(0))
+ * An append-only hash map mapping from a Long key to UnsafeRows.
+ *
+ * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array
+ * (`page`) in this format:
+ *
+ *  [bytes of row1][address1][bytes of row2][address2] ...
+ *
+ * address1 (8 bytes) is the offset and size of the next value for the same key as row1; any key
+ * could have multiple values. The address at the end of the last value for every key is 0.
+ *
+ * The keys and the addresses of their values could be stored in two modes:
+ *
+ * 1) sparse mode: the keys and addresses are stored in `array` as:
+ *
+ *  [key1][address1][key2][address2]...[]
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1
+ * is determined by a multiplicative hash of key1; collisions are resolved by linear probing over
+ * the key/address slot pairs.
+ *
+ * 2) dense mode: all the addresses are packed into a single array of longs, as:
+ *
+ *  [address1] [address2] ...
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1; the position is
+ * determined by `key1 - minKey`.
+ *
+ * The map is created in sparse mode, and key-value pairs can then be appended to it. Once
+ * appending is finished, the caller can call optimize() to try to turn the map into dense mode,
+ * which is faster to probe.
+ *
+ * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/
+ */
+private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int)
+  extends MemoryConsumer(mm) with Externalizable {
+
+  // Whether the keys are stored in dense mode or not.
+  private var isDense = false
+
+  // The minimum key
+  private var minKey = Long.MaxValue
+
+  // The maximum key
+  private var maxKey = Long.MinValue
+
+  // The array to store the key and the offset of the UnsafeRow in the page.
+  //
+  //  Sparse mode: [key1] [offset1 | size1] [key2] [offset2 | size2] ...
+  //  Dense mode: [offset1 | size1] [offset2 | size2]
+  private var array: Array[Long] = null
+  private var mask: Int = 0
+
+  // The page to store all bytes of UnsafeRow and the pointer to next rows.
+  // [row1][pointer1] [row2][pointer2]
+  private var page: Array[Byte] = null
+
+  // Current write cursor in the page.
+  private var cursor = Platform.BYTE_ARRAY_OFFSET
+
+  // The total number of values of all keys.
+  private var numValues = 0
+
+  // The number of unique keys.
+  private var numKeys = 0
+
+  // needed by serializer
+  def this() = {
+    this(
+      new TaskMemoryManager(
+        new StaticMemoryManager(
+          new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+          Long.MaxValue,
+          Long.MaxValue,
+          1),
+        0),
+      0)
   }
-}
 
-private[joins] final class GeneralLongHashedRelation(
-    private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
-  extends LongHashedRelation with Externalizable {
+  private def acquireMemory(size: Long): Unit = {
+    // do not support spilling
+    val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this)
+    if (got < size) {
+      freeMemory(got)
+      throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " +
+        s"got $got bytes")
+    }
+  }
 
-  // Needed for serialization (it is public to make Java serialization work)
-  def this() = this(null)
+  private def freeMemory(size: Long): Unit = {
+    mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this)
+  }
+
+  private def init(): Unit = {
+    if (mm != null) {
+      var n = 1
+      while (n < capacity) n *= 2
+      acquireMemory(n * 2 * 8 + (1 << 20))
+      array = new Array[Long](n * 2)
+      mask = n * 2 - 2
+      page = new Array[Byte](1 << 20)  // 1M bytes
+    }
+  }
 
-  override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+  init()
 
-  override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
-  }
+  def spill(size: Long, trigger: MemoryConsumer): Long = 0L
+
+  /**
+   * Returns whether all the keys are unique.
+   */
+  def keyIsUnique: Boolean = numKeys == numValues
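Each address in `array` packs a value's location in `page` into a single Long: the upper 32 bits hold the byte offset (which already includes Platform.BYTE_ARRAY_OFFSET, since `cursor` starts there) and the lower 32 bits hold the row size; the 8-byte next-value pointer then lives at offset + size. A self-contained sketch of the encoding (illustration only, not part of the patch), matching how getRow() and valueIter() below unpack it:

object AddressPackingSketch {
  // Pack an offset into `page` and a value size into one Long:
  // [offset: high 32 bits][size: low 32 bits].
  def pack(offset: Long, size: Int): Long = (offset << 32) | (size & 0xffffffffL)

  def offsetOf(address: Long): Long = address >>> 32
  def sizeOf(address: Long): Int = (address & 0xffffffffL).toInt

  def main(args: Array[String]): Unit = {
    val addr = pack(offset = 16L, size = 24)
    assert(offsetOf(addr) == 16L && sizeOf(addr) == 24) // round-trips
  }
}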
+
+  /**
+   * Returns total memory consumption.
+   */
+  def getTotalMemoryConsumption: Long = array.length * 8 + page.length
+
+  /**
+   * Returns the first slot of array that stores the keys (sparse mode).
+   */
+  private def firstSlot(key: Long): Int = {
+    val h = key * 0x9E3779B9L
+    (h ^ (h >> 32)).toInt & mask
+  }
 
-  override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
-  }
-}
+  /**
+   * Returns the next probe in the array.
+   */
+  private def nextSlot(pos: Int): Int = (pos + 2) & mask
+
+  private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
+    val offset = address >>> 32
+    val size = address & 0xffffffffL
+    resultRow.pointTo(page, offset, size.toInt)
+    resultRow
+  }
 
-private[joins] final class UniqueLongHashedRelation(
-    private var hashTable: JavaHashMap[Long, UnsafeRow])
-  extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+  /**
+   * Returns the single UnsafeRow for the given key, or null if not found.
+   */
+  def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
+    if (isDense) {
+      val idx = (key - minKey).toInt
+      if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+        return getRow(array(idx), resultRow)
+      }
+    } else {
+      var pos = firstSlot(key)
+      while (array(pos + 1) != 0) {
+        if (array(pos) == key) {
+          return getRow(array(pos + 1), resultRow)
+        }
+        pos = nextSlot(pos)
+      }
+    }
+    null
+  }
 
-  // Needed for serialization (it is public to make Java serialization work)
-  def this() = this(null)
+  /**
+   * Returns an iterator of UnsafeRow for the multiple linked values.
+   */
+  private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+    new Iterator[UnsafeRow] {
+      var addr = address
+      override def hasNext: Boolean = addr != 0
+      override def next(): UnsafeRow = {
+        val offset = addr >>> 32
+        val size = addr & 0xffffffffL
+        resultRow.pointTo(page, offset, size.toInt)
+        addr = Platform.getLong(page, offset + size)
+        resultRow
+      }
+    }
+  }
 
-  override def getValue(key: InternalRow): InternalRow = {
-    getValue(key.getLong(0))
-  }
-
-  override def getValue(key: Long): InternalRow = {
-    hashTable.get(key)
-  }
+  /**
+   * Returns an iterator for all the values of the given key, or null if no value is found.
+   */
+  def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+    if (isDense) {
+      val idx = (key - minKey).toInt
+      if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+        return valueIter(array(idx), resultRow)
+      }
+    } else {
+      var pos = firstSlot(key)
+      while (array(pos + 1) != 0) {
+        if (array(pos) == key) {
+          return valueIter(array(pos + 1), resultRow)
+        }
+        pos = nextSlot(pos)
+      }
+    }
+    null
+  }
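The sparse layout is probed in key/address pairs: firstSlot() spreads keys with a multiplicative hash (the golden-ratio constant 0x9E3779B9), and nextSlot() advances by one pair (step 2) under `mask`, i.e. linear probing. A self-contained sketch of the same probe sequence (illustration only, not part of the patch; `mask` must come from a power-of-two array length as in init()):

object ProbingSketch {
  // mask = array.length - 2, where array.length is a power of two (see init()),
  // so `& mask` both wraps around and keeps the slot index even (a key slot).
  def firstSlot(key: Long, mask: Int): Int = {
    val h = key * 0x9E3779B9L  // multiplicative hash, golden-ratio constant
    (h ^ (h >> 32)).toInt & mask
  }

  def nextSlot(pos: Int, mask: Int): Int = (pos + 2) & mask // skip to the next key/address pair

  def main(args: Array[String]): Unit = {
    val mask = 1024 - 2
    val pos = firstSlot(42L, mask)
    println(s"probe sequence: $pos, ${nextSlot(pos, mask)}, ${nextSlot(nextSlot(pos, mask), mask)}")
  }
}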
+
+  /**
+   * Appends the key and row into this map.
+   */
+  def append(key: Long, row: UnsafeRow): Unit = {
+    if (key < minKey) {
+      minKey = key
+    }
+    if (key > maxKey) {
+      maxKey = key
+    }
+
+    // Reserve 8 bytes for the pointer to the next value
+    if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+      val used = page.length
+      if (used * 2L > (1L << 31)) {
+        sys.error("Can't allocate a page that is larger than 2G")
+      }
+      acquireMemory(used * 2)
+      val newPage = new Array[Byte](used * 2)
+      System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+      page = newPage
+      freeMemory(used)
+    }
+
+    // copy the bytes of UnsafeRow
+    val offset = cursor
+    Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
+    cursor += row.getSizeInBytes
+    Platform.putLong(page, cursor, 0)
+    cursor += 8
+    numValues += 1
+    updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+  }
+
+  /**
+   * Updates the address in array for the given key.
+   */
+  private def updateIndex(key: Long, address: Long): Unit = {
+    var pos = firstSlot(key)
+    while (array(pos) != key && array(pos + 1) != 0) {
+      pos = nextSlot(pos)
+    }
+    if (array(pos + 1) == 0) {
+      // this is the first value for this key, put the address in array.
+      array(pos) = key
+      array(pos + 1) = address
+      numKeys += 1
+      if (numKeys * 4 > array.length) {
+        // reached half of the capacity
+        growArray()
+      }
+    } else {
+      // there are some values for this key, put the address at the front of them.
+      val pointer = (address >>> 32) + (address & 0xffffffffL)
+      Platform.putLong(page, pointer, array(pos + 1))
+      array(pos + 1) = address
+    }
+  }
+
+  private def growArray(): Unit = {
+    var old_array = array
+    val n = array.length
+    numKeys = 0
+    acquireMemory(n * 2 * 8)
+    array = new Array[Long](n * 2)
+    mask = n * 2 - 2
+    var i = 0
+    while (i < old_array.length) {
+      if (old_array(i + 1) > 0) {
+        updateIndex(old_array(i), old_array(i + 1))
+      }
+      i += 2
+    }
+    old_array = null  // release the reference to old array
+    freeMemory(n * 8)
+  }
+
+  /**
+   * Try to turn the map into dense mode, which is faster to probe.
+   */
+  def optimize(): Unit = {
+    val range = maxKey - minKey
+    // Convert to dense mode if it does not require more memory or could fit within L1 cache
+    if (range < array.length || range < 1024) {
+      try {
+        acquireMemory((range + 1) * 8)
+      } catch {
+        case e: SparkException =>
+          // not enough memory to convert
+          return
+      }
+      val denseArray = new Array[Long]((range + 1).toInt)
+      var i = 0
+      while (i < array.length) {
+        if (array(i + 1) > 0) {
+          val idx = (array(i) - minKey).toInt
+          denseArray(idx) = array(i + 1)
+        }
+        i += 2
+      }
+      val old_length = array.length
+      array = denseArray
+      isDense = true
+      freeMemory(old_length * 8)
+    }
+  }
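The conversion guard in optimize() boils down to a size heuristic: dense mode needs range + 1 address slots, so it is attempted when that is no larger than the sparse array, or when the dense array is tiny (under 1024 longs, about 8KB) and can stay cache-resident. A self-contained restatement of that guard (illustration only, not part of the patch):

object DenseModeHeuristicSketch {
  // Mirrors the condition in optimize(): convert when the dense array would be
  // no bigger than the sparse one, or small enough to sit in cache.
  def shouldConvert(minKey: Long, maxKey: Long, sparseArrayLength: Int): Boolean = {
    val range = maxKey - minKey // dense mode needs range + 1 address slots
    range < sparseArrayLength || range < 1024
  }

  def main(args: Array[String]): Unit = {
    println(shouldConvert(minKey = 100L, maxKey = 200L, sparseArrayLength = 64))   // true: tiny range
    println(shouldConvert(minKey = 0L, maxKey = 1L << 40, sparseArrayLength = 64)) // false: too sparse
  }
}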
+
+  /**
+   * Free all the memory acquired by this map.
+   */
+  def free(): Unit = {
+    if (page != null) {
+      freeMemory(page.length)
+      page = null
+    }
+    if (array != null) {
+      freeMemory(array.length * 8)
+      array = null
+    }
+  }
 
   override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+    out.writeBoolean(isDense)
+    out.writeLong(minKey)
+    out.writeLong(maxKey)
+    out.writeInt(numKeys)
+    out.writeInt(numValues)
+
+    out.writeInt(array.length)
+    val buffer = new Array[Byte](4 << 10)
+    var offset = Platform.LONG_ARRAY_OFFSET
+    val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+    while (offset < end) {
+      val size = Math.min(buffer.length, end - offset)
+      Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+      out.write(buffer, 0, size)
+      offset += size
+    }
+
+    val used = cursor - Platform.BYTE_ARRAY_OFFSET
+    out.writeInt(used)
+    out.write(page, 0, used)
   }
 
   override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+    isDense = in.readBoolean()
+    minKey = in.readLong()
+    maxKey = in.readLong()
+    numKeys = in.readInt()
+    numValues = in.readInt()
+
+    val length = in.readInt()
+    array = new Array[Long](length)
+    mask = length - 2
+    val buffer = new Array[Byte](4 << 10)
+    var offset = Platform.LONG_ARRAY_OFFSET
+    val end = length * 8 + Platform.LONG_ARRAY_OFFSET
+    while (offset < end) {
+      val size = Math.min(buffer.length, end - offset)
+      in.readFully(buffer, 0, size)
+      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
+      offset += size
+    }
+
+    val numBytes = in.readInt()
+    page = new Array[Byte](numBytes)
+    in.readFully(page)
   }
 }
 
-/**
- * A relation that pack all the rows into a byte array, together with offsets and sizes.
- *
- * All the bytes of UnsafeRow are packed together as `bytes`:
- *
- *  [ Row0 ][ Row1 ][] ... [ RowN ]
- *
- * With keys:
- *
- *   start   start+1   ...   start+N
- *
- * `offsets` are offsets of UnsafeRows in the `bytes`
- * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
- *
- * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as:
- *
- *  start   = 3
- *  offsets = [0, 0, 24]
- *  sizes   = [24, 0, 32]
- *  bytes   = [0 - 24][][24 - 56]
- */
-private[joins] final class LongArrayRelation(
-    private var numFields: Int,
-    private var start: Long,
-    private var offsets: Array[Int],
-    private var sizes: Array[Int],
-    private var bytes: Array[Byte]
-  ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+private[joins] class LongHashedRelation(
+    private var nFields: Int,
+    private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable {
+
+  private var resultRow: UnsafeRow = new UnsafeRow(nFields)
 
   // Needed for serialization (it is public to make Java serialization work)
-  def this() = this(0, 0L, null, null, null)
+  def this() = this(0, null)
 
-  override def getValue(key: InternalRow): InternalRow = {
-    getValue(key.getLong(0))
-  }
+  override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map)
 
-  override def getMemorySize: Long = {
-    offsets.length * 4 + sizes.length * 4 + bytes.length
-  }
+  override def estimatedSize: Long = map.getTotalMemoryConsumption
 
-  override def getValue(key: Long): InternalRow = {
-    val idx = (key - start).toInt
-    if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
-      val result = new UnsafeRow(numFields)
-      result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
-      result
+  override def get(key: InternalRow): Iterator[InternalRow] = {
+    if (key.isNullAt(0)) {
+      null
     } else {
+      get(key.getLong(0))
+    }
+  }
+
+  override def getValue(key: InternalRow): InternalRow = {
+    if (key.isNullAt(0)) {
       null
+    } else {
+      getValue(key.getLong(0))
     }
   }
 
+  override def get(key: Long): Iterator[InternalRow] = map.get(key, resultRow)
+
+  override def getValue(key: Long): InternalRow = map.getValue(key, resultRow)
+
+  override def keyIsUnique: Boolean = map.keyIsUnique
+
+  override def close(): Unit = {
+    map.free()
+  }
+
   override def writeExternal(out: ObjectOutput): Unit = {
-    out.writeInt(numFields)
-    out.writeLong(start)
-    out.writeInt(sizes.length)
-    var i = 0
-    while (i < sizes.length) {
-      out.writeInt(sizes(i))
-      i += 1
-    }
-    out.writeInt(bytes.length)
-    out.write(bytes)
+    out.writeInt(nFields)
+    out.writeObject(map)
   }
 
   override def readExternal(in: ObjectInput): Unit = {
-    numFields = in.readInt()
-    start = in.readLong()
-    val length = in.readInt()
-    // read sizes of rows
-    sizes = new Array[Int](length)
-    offsets = new Array[Int](length)
-    var i = 0
-    var offset = 0
-    while (i < length) {
-      offsets(i) = offset
-      sizes(i) = in.readInt()
-      offset += sizes(i)
-      i += 1
-    }
-    // read all the bytes
-    val total = in.readInt()
-    assert(total == offset)
-    bytes = new Array[Byte](total)
-    in.readFully(bytes)
+    nFields = in.readInt()
+    resultRow = new UnsafeRow(nFields)
+    map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
   }
 }
 
 /**
- * Create hashed relation with key that is long.
- */
+ * Create a hashed relation with a key that is a Long.
+ */
 private[joins] object LongHashedRelation {
-
-  val DENSE_FACTOR = 0.2
-
   def apply(
-    input: Iterator[InternalRow],
-    keyGenerator: Projection,
-    sizeEstimate: Int): HashedRelation = {
+      input: Iterator[InternalRow],
+      key: Seq[Expression],
+      sizeEstimate: Int,
+      taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
-    val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+    val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
+    val keyGenerator = UnsafeProjection.create(key)
 
     // Create a mapping of key -> rows
     var numFields = 0
-    var keyIsUnique = true
-    var minKey = Long.MaxValue
-    var maxKey = Long.MinValue
     while (input.hasNext) {
       val unsafeRow = input.next().asInstanceOf[UnsafeRow]
       numFields = unsafeRow.numFields()
       val rowKey = keyGenerator(unsafeRow)
-      if (!rowKey.anyNull) {
+      if (!rowKey.isNullAt(0)) {
         val key = rowKey.getLong(0)
-        minKey = math.min(minKey, key)
-        maxKey = math.max(maxKey, key)
-        val existingMatchList = hashTable.get(key)
-        val matchList = if (existingMatchList == null) {
-          val newMatchList = new CompactBuffer[UnsafeRow]()
-          hashTable.put(key, newMatchList)
-          newMatchList
-        } else {
-          keyIsUnique = false
-          existingMatchList
-        }
-        matchList += unsafeRow
+        map.append(key, unsafeRow)
       }
     }
-
-    if (keyIsUnique) {
-      if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
-        // The keys are dense enough, so use LongArrayRelation
-        val length = (maxKey - minKey).toInt + 1
-        val sizes = new Array[Int](length)
-        val offsets = new Array[Int](length)
-        var offset = 0
-        var i = 0
-        while (i < length) {
-          val rows = hashTable.get(i + minKey)
-          if (rows != null) {
-            offsets(i) = offset
-            sizes(i) = rows(0).getSizeInBytes
-            offset += sizes(i)
-          }
-          i += 1
-        }
-        val bytes = new Array[Byte](offset)
-        i = 0
-        while (i < length) {
-          val rows = hashTable.get(i + minKey)
-          if (rows != null) {
-            rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
-          }
-          i += 1
-        }
-        new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
-
-      } else {
-        // all the keys are unique, one row per key.
-        val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
-        val iter = hashTable.entrySet().iterator()
-        while (iter.hasNext) {
-          val entry = iter.next()
-          uniqHashTable.put(entry.getKey, entry.getValue()(0))
-        }
-        new UniqueLongHashedRelation(uniqHashTable)
-      }
-    } else {
-      new GeneralLongHashedRelation(hashTable)
-    }
+    map.optimize()
+    new LongHashedRelation(numFields, map)
   }
 }
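A short probing sketch for a relation built by the object above (illustration only, not part of the patch; assumes caller code in the org.apache.spark.sql.execution.joins package, since LongHashedRelation is private[joins]). Both get() and getValue() return null for absent keys, and the returned iterator re-uses a single UnsafeRow, so rows must be copied if they are kept:

import org.apache.spark.sql.catalyst.InternalRow

object LongLookupSketch {
  def lookupAll(relation: LongHashedRelation, key: Long): Seq[InternalRow] = {
    val it = relation.get(key)
    if (it == null) {
      Seq.empty
    } else {
      it.map(_.copy()).toSeq // copy each row: the iterator re-uses one buffer
    }
  }
}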
 
 /** The HashedRelationBroadcastMode requires that rows are broadcast as a HashedRelation.
   */
-private[execution] case class HashedRelationBroadcastMode(
-    canJoinKeyFitWithinLong: Boolean,
-    keys: Seq[Expression],
-    attributes: Seq[Attribute]) extends BroadcastMode {
+private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
+  extends BroadcastMode {
 
   override def transform(rows: Array[InternalRow]): HashedRelation = {
-    val generator = UnsafeProjection.create(keys, attributes)
-    HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
+    HashedRelation(rows.iterator, canonicalizedKey, rows.length)
   }
 
-  private lazy val canonicalizedKeys: Seq[Expression] = {
-    keys.map { e =>
-      BindReferences.bindReference(e.canonicalized, attributes)
-    }
+  private lazy val canonicalizedKey: Seq[Expression] = {
+    key.map { e => e.canonicalized }
   }
 
   override def compatibleWith(other: BroadcastMode): Boolean = other match {
-    case m: HashedRelationBroadcastMode =>
-      canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong &&
-        canonicalizedKeys == m.canonicalizedKeys
+    case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey
     case _ => false
   }
 }