Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 944
1 file changed, 539 insertions(+), 405 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528639..0427db4e3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,277 +18,189 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
-import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.{SparkConf, SparkEnv}
-import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException}
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.LongType
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
/**
* Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
* object.
*/
-private[execution] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
/**
- * Returns matched rows.
- */
- def get(key: InternalRow): Seq[InternalRow]
+ * Returns matched rows.
+ *
+ * Returns null if there are no matched rows.
+ */
+ def get(key: InternalRow): Iterator[InternalRow]
/**
- * Returns matched rows for a key that has only one column with LongType.
- */
- def get(key: Long): Seq[InternalRow] = {
+ * Returns matched rows for a key that has only one column with LongType.
+ *
+ * Returns null if there are no matched rows.
+ */
+ def get(key: Long): Iterator[InternalRow] = {
throw new UnsupportedOperationException
}
/**
- * Returns the size of used memory.
- */
- def getMemorySize: Long = 1L // to make the test happy
-
- // This is a helper method to implement Externalizable, and is used by
- // GeneralHashedRelation and UniqueKeyHashedRelation
- protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
- out.writeInt(serialized.length) // Write the length of serialized bytes first
- out.write(serialized)
- }
-
- // This is a helper method to implement Externalizable, and is used by
- // GeneralHashedRelation and UniqueKeyHashedRelation
- protected def readBytes(in: ObjectInput): Array[Byte] = {
- val serializedSize = in.readInt() // Read the length of serialized bytes first
- val bytes = new Array[Byte](serializedSize)
- in.readFully(bytes)
- bytes
- }
-}
-
-/**
- * Interface for a hashed relation that have only one row per key.
- *
- * We should call getValue() for better performance.
- */
-private[execution] trait UniqueHashedRelation extends HashedRelation {
-
- /**
- * Returns the matched single row.
- */
+ * Returns the matched single row.
+ */
def getValue(key: InternalRow): InternalRow
/**
- * Returns the matched single row with key that have only one column of LongType.
- */
+ * Returns the matched single row for a key that has only one column of LongType.
+ */
def getValue(key: Long): InternalRow = {
throw new UnsupportedOperationException
}
- override def get(key: InternalRow): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
+ /**
+ * Returns true iff all the keys are unique.
+ */
+ def keyIsUnique: Boolean
- override def get(key: Long): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
+ /**
+ * Returns a read-only copy of this, to be safely used in the current thread.
+ */
+ def asReadOnlyCopy(): HashedRelation
+
+ /**
+ * Release any used resources.
+ */
+ def close(): Unit
}
private[execution] object HashedRelation {
/**
* Create a HashedRelation from an Iterator of InternalRow.
- *
- * Note: The caller should make sure that these InternalRow are different objects.
*/
def apply(
- canJoinKeyFitWithinLong: Boolean,
input: Iterator[InternalRow],
- keyGenerator: Projection,
- sizeEstimate: Int = 64): HashedRelation = {
+ key: Seq[Expression],
+ sizeEstimate: Int = 64,
+ taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
+ val mm = Option(taskMemoryManager).getOrElse {
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ }
- if (canJoinKeyFitWithinLong) {
- LongHashedRelation(input, keyGenerator, sizeEstimate)
+ if (key.length == 1 && key.head.dataType == LongType) {
+ LongHashedRelation(input, key, sizeEstimate, mm)
} else {
- UnsafeHashedRelation(
- input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+ UnsafeHashedRelation(input, key, sizeEstimate, mm)
}
}
}
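// Illustrative sketch, not part of this patch: how a caller might build and
// probe a relation via HashedRelation.apply. `rowIter` (an Iterator of
// UnsafeRows) is hypothetical; a single LongType bound reference makes
// apply() take the LongHashedRelation path above.
//
//   val buildKey = Seq(BoundReference(0, LongType, nullable = false))
//   val relation = HashedRelation(rowIter, buildKey, sizeEstimate = 1024)
//   val matches = relation.get(42L)   // Iterator[InternalRow], or null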
/**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed by a BytesToBytesMap.
*
* It's serialized in the following format:
* [number of keys]
- * [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- * ...
- *
- * All the values are serialized as following:
- * [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- * ...
+ * [size of key] [size of value] [key bytes] [bytes for value]
*/
-private[joins] final class UnsafeHashedRelation(
- private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
- extends HashedRelation
- with KnownSizeEstimation
- with Externalizable {
+private[joins] class UnsafeHashedRelation(
+ private var numFields: Int,
+ private var binaryMap: BytesToBytesMap)
+ extends HashedRelation with Externalizable {
- private[joins] def this() = this(null) // Needed for serialization
+ private[joins] def this() = this(0, null) // Needed for serialization
- // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
- // This is used in broadcast joins and distributed mode only
- @transient private[this] var binaryMap: BytesToBytesMap = _
+ override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
- /**
- * Return the size of the unsafe map on the executors.
- *
- * For broadcast joins, this hashed relation is bigger on the driver because it is
- * represented as a Java hash map there. While serializing the map to the executors,
- * however, we rehash the contents in a binary map to reduce the memory footprint on
- * the executors.
- *
- * For non-broadcast joins or in local mode, return 0.
- */
- override def getMemorySize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- 0
- }
+ override def asReadOnlyCopy(): UnsafeHashedRelation = {
+ new UnsafeHashedRelation(numFields, binaryMap)
}
- override def estimatedSize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- SizeEstimator.estimate(hashTable)
- }
- }
+ override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption
- override def get(key: InternalRow): Seq[InternalRow] = {
- val unsafeKey = key.asInstanceOf[UnsafeRow]
+ // Re-used across get()/getValue() calls to avoid per-row allocation; not thread-safe.
+ var resultRow = new UnsafeRow(numFields)
- if (binaryMap != null) {
- // Used in Broadcast join
- val map = binaryMap // avoid the compiler error
- val loc = new map.Location // this could be allocated in stack
- binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
- if (loc.isDefined) {
- val buffer = CompactBuffer[UnsafeRow]()
-
- val base = loc.getValueBase
- var offset = loc.getValueOffset
- val last = offset + loc.getValueLength
- while (offset < last) {
- val numFields = Platform.getInt(base, offset)
- val sizeInBytes = Platform.getInt(base, offset + 4)
- offset += 8
-
- val row = new UnsafeRow(numFields)
- row.pointTo(base, offset, sizeInBytes)
- buffer += row
- offset += sizeInBytes
+ override def get(key: InternalRow): Iterator[InternalRow] = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ new Iterator[UnsafeRow] {
+ private var _hasNext = true
+ override def hasNext: Boolean = _hasNext
+ override def next(): UnsafeRow = {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ _hasNext = loc.nextValue()
+ resultRow
}
- buffer
- } else {
- null
}
+ } else {
+ null
+ }
+ }
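// Illustrative sketch, not part of this patch: consuming the iterator
// returned by get(). `relation`, `joinKey` and `buffer` are hypothetical.
// Because `resultRow` is re-used for every next() call, a row must be
// copied if it is kept beyond the current iteration.
//
//   val it = relation.get(joinKey)
//   if (it != null) {
//     while (it.hasNext) {
//       buffer += it.next().copy()   // copy(): the returned row is re-used
//     }
//   }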
+ def getValue(key: InternalRow): InternalRow = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ resultRow
} else {
- // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
- hashTable.get(unsafeKey)
+ null
}
}
- override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- if (binaryMap != null) {
- // This could happen when a cached broadcast object need to be dumped into disk to free memory
- out.writeInt(binaryMap.numElements())
-
- var buffer = new Array[Byte](64)
- def write(base: Object, offset: Long, length: Int): Unit = {
- if (buffer.length < length) {
- buffer = new Array[Byte](length)
- }
- Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
- out.write(buffer, 0, length)
- }
+ override def close(): Unit = {
+ binaryMap.free()
+ }
- val iter = binaryMap.iterator()
- while (iter.hasNext) {
- val loc = iter.next()
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(loc.getKeyLength)
- out.writeInt(loc.getValueLength)
- write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
- write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ out.writeInt(numFields)
+ // TODO: move these into BytesToBytesMap
+ out.writeInt(binaryMap.numKeys())
+ out.writeInt(binaryMap.numValues())
+
+ var buffer = new Array[Byte](64)
+ def write(base: Object, offset: Long, length: Int): Unit = {
+ if (buffer.length < length) {
+ buffer = new Array[Byte](length)
}
+ Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+ out.write(buffer, 0, length)
+ }
- } else {
- assert(hashTable != null)
- out.writeInt(hashTable.size())
-
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val key = entry.getKey
- val values = entry.getValue
-
- // write all the values as single byte array
- var totalSize = 0L
- var i = 0
- while (i < values.length) {
- totalSize += values(i).getSizeInBytes + 4 + 4
- i += 1
- }
- assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(key.getSizeInBytes)
- out.writeInt(totalSize.toInt)
- out.write(key.getBytes)
- i = 0
- while (i < values.length) {
- // [num of fields] [num of bytes] [row bytes]
- // write the integer in native order, so they can be read by UNSAFE.getInt()
- if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
- out.writeInt(values(i).numFields())
- out.writeInt(values(i).getSizeInBytes)
- } else {
- out.writeInt(Integer.reverseBytes(values(i).numFields()))
- out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
- }
- out.write(values(i).getBytes)
- i += 1
- }
- }
+ val iter = binaryMap.iterator()
+ while (iter.hasNext) {
+ val loc = iter.next()
+ // [key size] [values size] [key bytes] [value bytes]
+ out.writeInt(loc.getKeyLength)
+ out.writeInt(loc.getValueLength)
+ write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+ write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
}
}
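// Illustrative sketch, not part of this patch: the stream produced above is
//
//   [numFields: Int][numKeys: Int][numValues: Int]
//   [key size: Int][value size: Int][key bytes][value bytes]   (one record per entry)
//
// readExternal() below re-appends each record into a fresh map, so multiple
// values stored under one key are rebuilt as a linked chain for that key.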
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ numFields = in.readInt()
+ resultRow = new UnsafeRow(numFields)
val nKeys = in.readInt()
+ val nValues = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
// TODO(josh): This needs to be revisited before we merge this patch; making this change now
// so that tests compile:
@@ -314,7 +226,7 @@ private[joins] final class UnsafeHashedRelation(
var i = 0
var keyBuffer = new Array[Byte](1024)
var valuesBuffer = new Array[Byte](1024)
- while (i < nKeys) {
+ while (i < nValues) {
val keySize = in.readInt()
val valuesSize = in.readInt()
if (keySize > keyBuffer.length) {
@@ -326,13 +238,11 @@ private[joins] final class UnsafeHashedRelation(
}
in.readFully(valuesBuffer, 0, valuesSize)
- // put it into binary map
val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
- assert(!loc.isDefined, "Duplicated key found!")
- val putSuceeded = loc.putNewKey(
- keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+ val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
if (!putSuceeded) {
+ binaryMap.free()
throw new IOException("Could not allocate memory to grow BytesToBytesMap")
}
i += 1
@@ -344,279 +254,503 @@ private[joins] object UnsafeHashedRelation {
def apply(
input: Iterator[InternalRow],
- keyGenerator: UnsafeProjection,
- sizeEstimate: Int): HashedRelation = {
+ key: Seq[Expression],
+ sizeEstimate: Int,
+ taskMemoryManager: TaskMemoryManager): HashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
- // TODO: Use BytesToBytesMap for memory efficiency
- val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+ .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+ val binaryMap = new BytesToBytesMap(
+ taskMemoryManager,
+ // Only 70% of the slots can be used before growing; more capacity helps to reduce collisions
+ (sizeEstimate * 1.5 + 1).toInt,
+ pageSizeBytes)
// Create a mapping of buildKeys -> rows
+ val keyGenerator = UnsafeProjection.create(key)
+ var numFields = 0
while (input.hasNext) {
- val unsafeRow = input.next().asInstanceOf[UnsafeRow]
- val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[UnsafeRow]()
- hashTable.put(rowKey.copy(), newMatchList)
- newMatchList
- } else {
- existingMatchList
+ val row = input.next().asInstanceOf[UnsafeRow]
+ numFields = row.numFields()
+ val key = keyGenerator(row)
+ if (!key.anyNull) {
+ val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ val success = loc.append(
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+ if (!success) {
+ binaryMap.free()
+ throw new SparkException("There is no enough memory to build hash map")
}
- matchList += unsafeRow
}
}
- // TODO: create UniqueUnsafeRelation
- new UnsafeHashedRelation(hashTable)
+ new UnsafeHashedRelation(numFields, binaryMap)
}
}
/**
- * An interface for a hashed relation that the key is a Long.
- */
-private[joins] trait LongHashedRelation extends HashedRelation {
- override def get(key: InternalRow): Seq[InternalRow] = {
- get(key.getLong(0))
+ * An append-only hash map mapping from key of Long to UnsafeRow.
+ *
+ * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array
+ * (`page`) in this format:
+ *
+ * [bytes of row1][address1][bytes of row2][address2] ...
+ *
+ * address1 (8 bytes) is the offset and size of the next value for the same key as row1; any key
+ * could have multiple values. The address at the end of the last value for every key is 0.
+ *
+ * The keys and addresses of their values could be stored in two modes:
+ *
+ * 1) sparse mode: the keys and addresses are stored in `array` as:
+ *
+ * [key1][address1][key2][address2]...[]
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1
+ * is determined by a multiplicative hash of key1 (see firstSlot). Linear probing (stepping to
+ * the next slot) is used to resolve hash collisions.
+ *
+ * 2) dense mode: all the addresses are packed into a single array of long, as:
+ *
+ * [address1] [address2] ...
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is
+ * determined by `key1 - minKey`.
+ *
+ * The map is created in sparse mode, and key-value pairs can be appended to it. Once appending
+ * finishes, the caller can call optimize() to try to turn the map into dense mode, which is
+ * faster to probe.
+ *
+ * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/
+ */
+private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int)
+ extends MemoryConsumer(mm) with Externalizable {
+
+ // Whether the keys are stored in dense mode or not.
+ private var isDense = false
+
+ // The minimum key
+ private var minKey = Long.MaxValue
+
+ // The maximum key
+ private var maxKey = Long.MinValue
+
+ // The array to store the key and offset of UnsafeRow in the page.
+ //
+ // Sparse mode: [key1] [offset1 | size1] [key2] [offset2 | size2] ...
+ // Dense mode: [offset1 | size1] [offset2 | size2] ...
+ private var array: Array[Long] = null
+ private var mask: Int = 0
+
+ // The page to store all bytes of UnsafeRow and the pointer to next rows.
+ // [row1][pointer1] [row2][pointer2]
+ private var page: Array[Byte] = null
+
+ // Current write cursor in the page.
+ private var cursor = Platform.BYTE_ARRAY_OFFSET
+
+ // The total number of values of all keys.
+ private var numValues = 0
+
+ // The number of unique keys.
+ private var numKeys = 0
+
+ // needed by serializer
+ def this() = {
+ this(
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0),
+ 0)
}
-}
-private[joins] final class GeneralLongHashedRelation(
- private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
- extends LongHashedRelation with Externalizable {
+ private def acquireMemory(size: Long): Unit = {
+ // do not support spilling
+ val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this)
+ if (got < size) {
+ freeMemory(got)
+ throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " +
+ s"got $got bytes")
+ }
+ }
- // Needed for serialization (it is public to make Java serialization work)
- def this() = this(null)
+ private def freeMemory(size: Long): Unit = {
+ mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this)
+ }
+
+ private def init(): Unit = {
+ if (mm != null) {
+ var n = 1
+ while (n < capacity) n *= 2
+ acquireMemory(n * 2 * 8 + (1 << 20))
+ array = new Array[Long](n * 2)
+ mask = n * 2 - 2
+ page = new Array[Byte](1 << 20) // 1M bytes
+ }
+ }
- override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+ init()
- override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ def spill(size: Long, trigger: MemoryConsumer): Long = 0L
+
+ /**
+ * Returns whether all the keys are unique.
+ */
+ def keyIsUnique: Boolean = numKeys == numValues
+
+ /**
+ * Returns total memory consumption.
+ */
+ def getTotalMemoryConsumption: Long = array.length * 8 + page.length
+
+ /**
+ * Returns the first slot in the array that stores the key (sparse mode).
+ */
+ private def firstSlot(key: Long): Int = {
+ val h = key * 0x9E3779B9L
+ (h ^ (h >> 32)).toInt & mask
}
- override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ /**
+ * Returns the next probe in the array.
+ */
+ private def nextSlot(pos: Int): Int = (pos + 2) & mask
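// Illustrative sketch, not part of this patch: slots are (key, address)
// pairs on even indices, so `mask` is a power of two minus 2 and probing
// steps by 2. For an array of 8 longs (4 slots):
//
//   val mask = 8 - 2
//   def next(pos: Int): Int = (pos + 2) & mask
//   assert(next(6) == 0)    // the probe wraps back to slot 0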
+
+ private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
+ val offset = address >>> 32
+ val size = address & 0xffffffffL
+ resultRow.pointTo(page, offset, size.toInt)
+ resultRow
}
-}
-private[joins] final class UniqueLongHashedRelation(
- private var hashTable: JavaHashMap[Long, UnsafeRow])
- extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+ /**
+ * Returns the single UnsafeRow for the given key, or null if not found.
+ */
+ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
+ if (isDense) {
+ val idx = (key - minKey).toInt
+ if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+ return getRow(array(idx), resultRow)
+ }
+ } else {
+ var pos = firstSlot(key)
+ while (array(pos + 1) != 0) {
+ if (array(pos) == key) {
+ return getRow(array(pos + 1), resultRow)
+ }
+ pos = nextSlot(pos)
+ }
+ }
+ null
+ }
- // Needed for serialization (it is public to make Java serialization work)
- def this() = this(null)
+ /**
+ * Returns an iterator of UnsafeRow over multiple linked values.
+ */
+ private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+ new Iterator[UnsafeRow] {
+ var addr = address
+ override def hasNext: Boolean = addr != 0
+ override def next(): UnsafeRow = {
+ val offset = addr >>> 32
+ val size = addr & 0xffffffffL
+ resultRow.pointTo(page, offset, size.toInt)
+ addr = Platform.getLong(page, offset + size)
+ resultRow
+ }
+ }
+ }
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
+ /**
+ * Returns an iterator for all the values for the given key, or null if no value found.
+ */
+ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+ if (isDense) {
+ val idx = (key - minKey).toInt
+ if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+ return valueIter(array(idx), resultRow)
+ }
+ } else {
+ var pos = firstSlot(key)
+ while (array(pos + 1) != 0) {
+ if (array(pos) == key) {
+ return valueIter(array(pos + 1), resultRow)
+ }
+ pos = nextSlot(pos)
+ }
+ }
+ null
}
- override def getValue(key: Long): InternalRow = {
- hashTable.get(key)
+ /**
+ * Appends the key and row into this map.
+ */
+ def append(key: Long, row: UnsafeRow): Unit = {
+ if (key < minKey) {
+ minKey = key
+ }
+ if (key > maxKey) {
+ maxKey = key
+ }
+
+ // There are 8 bytes for the pointer to the next value
+ if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+ val used = page.length
+ if (used * 2L > (1L << 31)) {
+ sys.error("Can't allocate a page that is larger than 2G")
+ }
+ acquireMemory(used * 2)
+ val newPage = new Array[Byte](used * 2)
+ System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+ page = newPage
+ freeMemory(used)
+ }
+
+ // copy the bytes of UnsafeRow
+ val offset = cursor
+ Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
+ cursor += row.getSizeInBytes
+ Platform.putLong(page, cursor, 0)
+ cursor += 8
+ numValues += 1
+ updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+ }
+
+ /**
+ * Updates the address in the array for the given key.
+ */
+ private def updateIndex(key: Long, address: Long): Unit = {
+ var pos = firstSlot(key)
+ while (array(pos) != key && array(pos + 1) != 0) {
+ pos = nextSlot(pos)
+ }
+ if (array(pos + 1) == 0) {
+ // this is the first value for this key, put the address in array.
+ array(pos) = key
+ array(pos + 1) = address
+ numKeys += 1
+ if (numKeys * 4 > array.length) {
+ // reached half of the capacity
+ growArray()
+ }
+ } else {
+ // there are already values for this key; put the new address at the front of the chain.
+ val pointer = (address >>> 32) + (address & 0xffffffffL)
+ Platform.putLong(page, pointer, array(pos + 1))
+ array(pos + 1) = address
+ }
+ }
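// Illustrative sketch, not part of this patch: appending two rows under the
// same key links them through the 8-byte pointer that trails each row, so
// lookups see the most recently appended value first.
//
//   map.append(7L, rowA)      // page: [rowA][0]
//   map.append(7L, rowB)      // page: [rowA][0][rowB][addr(rowA)]
//   // map.get(7L, resultRow) now yields rowB, then rowA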
+
+ private def growArray(): Unit = {
+ var old_array = array
+ val n = array.length
+ numKeys = 0
+ acquireMemory(n * 2 * 8)
+ array = new Array[Long](n * 2)
+ mask = n * 2 - 2
+ var i = 0
+ while (i < old_array.length) {
+ if (old_array(i + 1) > 0) {
+ updateIndex(old_array(i), old_array(i + 1))
+ }
+ i += 2
+ }
+ old_array = null // release the reference to old array
+ freeMemory(n * 8)
+ }
+
+ /**
+ * Try to turn the map into dense mode, which is faster to probe.
+ */
+ def optimize(): Unit = {
+ val range = maxKey - minKey
+ // Convert to dense mode if it does not require more memory or could fit within L1 cache
+ if (range < array.length || range < 1024) {
+ try {
+ acquireMemory((range + 1) * 8)
+ } catch {
+ case e: SparkException =>
+ // there is not enough memory to convert
+ return
+ }
+ val denseArray = new Array[Long]((range + 1).toInt)
+ var i = 0
+ while (i < array.length) {
+ if (array(i + 1) > 0) {
+ val idx = (array(i) - minKey).toInt
+ denseArray(idx) = array(i + 1)
+ }
+ i += 2
+ }
+ val old_length = array.length
+ array = denseArray
+ isDense = true
+ freeMemory(old_length * 8)
+ }
+ }
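// Illustrative sketch, not part of this patch: a worked example of the
// dense-mode condition. With minKey = 100 and maxKey = 1000, range = 900;
// if the sparse array holds 2048 longs, range < array.length holds, so a
// 901-entry dense array is allocated and key 250 lands at index
// (250 - minKey) = 150, replacing hashing with direct indexing.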
+
+ /**
+ * Free all the memory acquired by this map.
+ */
+ def free(): Unit = {
+ if (page != null) {
+ freeMemory(page.length)
+ page = null
+ }
+ if (array != null) {
+ freeMemory(array.length * 8)
+ array = null
+ }
}
override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ out.writeBoolean(isDense)
+ out.writeLong(minKey)
+ out.writeLong(maxKey)
+ out.writeInt(numKeys)
+ out.writeInt(numValues)
+
+ out.writeInt(array.length)
+ val buffer = new Array[Byte](4 << 10)
+ var offset = Platform.LONG_ARRAY_OFFSET
+ val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+ while (offset < end) {
+ val size = Math.min(buffer.length, end - offset)
+ Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+ out.write(buffer, 0, size)
+ offset += size
+ }
+
+ val used = cursor - Platform.BYTE_ARRAY_OFFSET
+ out.writeInt(used)
+ out.write(page, 0, used)
}
override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ isDense = in.readBoolean()
+ minKey = in.readLong()
+ maxKey = in.readLong()
+ numKeys = in.readInt()
+ numValues = in.readInt()
+
+ val length = in.readInt()
+ array = new Array[Long](length)
+ mask = length - 2
+ val buffer = new Array[Byte](4 << 10)
+ var offset = Platform.LONG_ARRAY_OFFSET
+ val end = length * 8 + Platform.LONG_ARRAY_OFFSET
+ while (offset < end) {
+ val size = Math.min(buffer.length, end - offset)
+ in.readFully(buffer, 0, size)
+ Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
+ offset += size
+ }
+
+ val numBytes = in.readInt()
+ page = new Array[Byte](numBytes)
+ in.readFully(page)
}
}
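// Illustrative sketch, not part of this patch: the append/optimize/get life
// cycle of LongToUnsafeRowMap. `taskMemoryManager`, `rows`, `numFields` and
// `someKey` are hypothetical; free() must be called to release the memory
// this consumer acquired.
//
//   val map = new LongToUnsafeRowMap(taskMemoryManager, capacity = 64)
//   rows.foreach { case (k, row) => map.append(k, row) }
//   map.optimize()                               // try to switch to dense mode
//   val resultRow = new UnsafeRow(numFields)
//   val hit = map.getValue(someKey, resultRow)   // null if absent
//   map.free()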
-/**
- * A relation that pack all the rows into a byte array, together with offsets and sizes.
- *
- * All the bytes of UnsafeRow are packed together as `bytes`:
- *
- * [ Row0 ][ Row1 ][] ... [ RowN ]
- *
- * With keys:
- *
- * start start+1 ... start+N
- *
- * `offsets` are offsets of UnsafeRows in the `bytes`
- * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
- *
- * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as:
- *
- * start = 3
- * offsets = [0, 0, 24]
- * sizes = [24, 0, 32]
- * bytes = [0 - 24][][24 - 56]
- */
-private[joins] final class LongArrayRelation(
- private var numFields: Int,
- private var start: Long,
- private var offsets: Array[Int],
- private var sizes: Array[Int],
- private var bytes: Array[Byte]
- ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+private[joins] class LongHashedRelation(
+ private var nFields: Int,
+ private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable {
+
+ private var resultRow: UnsafeRow = new UnsafeRow(nFields)
// Needed for serialization (it is public to make Java serialization work)
- def this() = this(0, 0L, null, null, null)
+ def this() = this(0, null)
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
- }
+ override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map)
- override def getMemorySize: Long = {
- offsets.length * 4 + sizes.length * 4 + bytes.length
- }
+ override def estimatedSize: Long = map.getTotalMemoryConsumption
- override def getValue(key: Long): InternalRow = {
- val idx = (key - start).toInt
- if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
- val result = new UnsafeRow(numFields)
- result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
- result
+ override def get(key: InternalRow): Iterator[InternalRow] = {
+ if (key.isNullAt(0)) {
+ null
} else {
+ get(key.getLong(0))
+ }
+ }
+
+ override def getValue(key: InternalRow): InternalRow = {
+ if (key.isNullAt(0)) {
null
+ } else {
+ getValue(key.getLong(0))
}
}
+ override def get(key: Long): Iterator[InternalRow] = map.get(key, resultRow)
+
+ override def getValue(key: Long): InternalRow = map.getValue(key, resultRow)
+
+ override def keyIsUnique: Boolean = map.keyIsUnique
+
+ override def close(): Unit = {
+ map.free()
+ }
+
override def writeExternal(out: ObjectOutput): Unit = {
- out.writeInt(numFields)
- out.writeLong(start)
- out.writeInt(sizes.length)
- var i = 0
- while (i < sizes.length) {
- out.writeInt(sizes(i))
- i += 1
- }
- out.writeInt(bytes.length)
- out.write(bytes)
+ out.writeInt(nFields)
+ out.writeObject(map)
}
override def readExternal(in: ObjectInput): Unit = {
- numFields = in.readInt()
- start = in.readLong()
- val length = in.readInt()
- // read sizes of rows
- sizes = new Array[Int](length)
- offsets = new Array[Int](length)
- var i = 0
- var offset = 0
- while (i < length) {
- offsets(i) = offset
- sizes(i) = in.readInt()
- offset += sizes(i)
- i += 1
- }
- // read all the bytes
- val total = in.readInt()
- assert(total == offset)
- bytes = new Array[Byte](total)
- in.readFully(bytes)
+ nFields = in.readInt()
+ resultRow = new UnsafeRow(nFields)
+ map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
}
}
/**
- * Create hashed relation with key that is long.
- */
+ * Creates a hashed relation with a key that is a long.
+ */
private[joins] object LongHashedRelation {
-
- val DENSE_FACTOR = 0.2
-
def apply(
- input: Iterator[InternalRow],
- keyGenerator: Projection,
- sizeEstimate: Int): HashedRelation = {
+ input: Iterator[InternalRow],
+ key: Seq[Expression],
+ sizeEstimate: Int,
+ taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
- val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
+ val keyGenerator = UnsafeProjection.create(key)
// Create a mapping of key -> rows
var numFields = 0
- var keyIsUnique = true
- var minKey = Long.MaxValue
- var maxKey = Long.MinValue
while (input.hasNext) {
val unsafeRow = input.next().asInstanceOf[UnsafeRow]
numFields = unsafeRow.numFields()
val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.anyNull) {
+ if (!rowKey.isNullAt(0)) {
val key = rowKey.getLong(0)
- minKey = math.min(minKey, key)
- maxKey = math.max(maxKey, key)
- val existingMatchList = hashTable.get(key)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[UnsafeRow]()
- hashTable.put(key, newMatchList)
- newMatchList
- } else {
- keyIsUnique = false
- existingMatchList
- }
- matchList += unsafeRow
+ map.append(key, unsafeRow)
}
}
-
- if (keyIsUnique) {
- if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
- // The keys are dense enough, so use LongArrayRelation
- val length = (maxKey - minKey).toInt + 1
- val sizes = new Array[Int](length)
- val offsets = new Array[Int](length)
- var offset = 0
- var i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- offsets(i) = offset
- sizes(i) = rows(0).getSizeInBytes
- offset += sizes(i)
- }
- i += 1
- }
- val bytes = new Array[Byte](offset)
- i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
- }
- i += 1
- }
- new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
-
- } else {
- // all the keys are unique, one row per key.
- val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- uniqHashTable.put(entry.getKey, entry.getValue()(0))
- }
- new UniqueLongHashedRelation(uniqHashTable)
- }
- } else {
- new GeneralLongHashedRelation(hashTable)
- }
+ map.optimize()
+ new LongHashedRelation(numFields, map)
}
}
/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
-private[execution] case class HashedRelationBroadcastMode(
- canJoinKeyFitWithinLong: Boolean,
- keys: Seq[Expression],
- attributes: Seq[Attribute]) extends BroadcastMode {
+private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
+ extends BroadcastMode {
override def transform(rows: Array[InternalRow]): HashedRelation = {
- val generator = UnsafeProjection.create(keys, attributes)
- HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
+ HashedRelation(rows.iterator, canonicalizedKey, rows.length)
}
- private lazy val canonicalizedKeys: Seq[Expression] = {
- keys.map { e =>
- BindReferences.bindReference(e.canonicalized, attributes)
- }
+ private lazy val canonicalizedKey: Seq[Expression] = {
+ key.map { e => e.canonicalized }
}
override def compatibleWith(other: BroadcastMode): Boolean = other match {
- case m: HashedRelationBroadcastMode =>
- canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong &&
- canonicalizedKeys == m.canonicalizedKeys
+ case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey
case _ => false
}
}
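// Illustrative sketch, not part of this patch: two broadcast modes are
// compatible when their canonicalized key expressions match, e.g. the same
// bound key column referenced from two copies of a plan.
//
//   val k = Seq(BoundReference(0, LongType, nullable = false))
//   val m1 = HashedRelationBroadcastMode(k)
//   val m2 = HashedRelationBroadcastMode(k.map(_.canonicalized))
//   assert(m1.compatibleWith(m2))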