aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala281
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala18
3 files changed, 141 insertions, 160 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 8882903bbf..1f1b5389aa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -134,7 +134,7 @@ public final class UnsafeFixedWidthAggregationMap {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
- boolean putSucceeded = loc.putNewKey(
+ boolean putSucceeded = loc.append(
key.getBaseObject(),
key.getBaseOffset(),
key.getSizeInBytes(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528639..dc4793e85a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,18 +18,18 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
import org.apache.spark.util.collection.CompactBuffer
/**
@@ -54,6 +54,11 @@ private[execution] sealed trait HashedRelation {
*/
def getMemorySize: Long = 1L // to make the test happy
+ /**
+ * Release any used resources.
+ */
+ def close(): Unit = {}
+
// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -132,163 +137,83 @@ private[execution] object HashedRelation {
}
/**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap.
*
* It's serialized in the following format:
* [number of keys]
- * [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- * ...
- *
- * All the values are serialized as following:
- * [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- * ...
+ * [size of key] [size of value] [key bytes] [bytes for value]
*/
-private[joins] final class UnsafeHashedRelation(
- private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
- extends HashedRelation
- with KnownSizeEstimation
- with Externalizable {
-
- private[joins] def this() = this(null) // Needed for serialization
+private[joins] class UnsafeHashedRelation(
+ private var numFields: Int,
+ private var binaryMap: BytesToBytesMap)
+ extends HashedRelation with KnownSizeEstimation with Externalizable {
- // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
- // This is used in broadcast joins and distributed mode only
- @transient private[this] var binaryMap: BytesToBytesMap = _
+ private[joins] def this() = this(0, null) // Needed for serialization
- /**
- * Return the size of the unsafe map on the executors.
- *
- * For broadcast joins, this hashed relation is bigger on the driver because it is
- * represented as a Java hash map there. While serializing the map to the executors,
- * however, we rehash the contents in a binary map to reduce the memory footprint on
- * the executors.
- *
- * For non-broadcast joins or in local mode, return 0.
- */
override def getMemorySize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- 0
- }
+ binaryMap.getTotalMemoryConsumption
}
override def estimatedSize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- SizeEstimator.estimate(hashTable)
- }
+ binaryMap.getTotalMemoryConsumption
}
override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
-
- if (binaryMap != null) {
- // Used in Broadcast join
- val map = binaryMap // avoid the compiler error
- val loc = new map.Location // this could be allocated in stack
- binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
- if (loc.isDefined) {
- val buffer = CompactBuffer[UnsafeRow]()
-
- val base = loc.getValueBase
- var offset = loc.getValueOffset
- val last = offset + loc.getValueLength
- while (offset < last) {
- val numFields = Platform.getInt(base, offset)
- val sizeInBytes = Platform.getInt(base, offset + 4)
- offset += 8
-
- val row = new UnsafeRow(numFields)
- row.pointTo(base, offset, sizeInBytes)
- buffer += row
- offset += sizeInBytes
- }
- buffer
- } else {
- null
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ val buffer = CompactBuffer[UnsafeRow]()
+ val row = new UnsafeRow(numFields)
+ row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ buffer += row
+ while (loc.nextValue()) {
+ val row = new UnsafeRow(numFields)
+ row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ buffer += row
}
-
+ buffer
} else {
- // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
- hashTable.get(unsafeKey)
+ null
}
}
- override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- if (binaryMap != null) {
- // This could happen when a cached broadcast object need to be dumped into disk to free memory
- out.writeInt(binaryMap.numElements())
-
- var buffer = new Array[Byte](64)
- def write(base: Object, offset: Long, length: Int): Unit = {
- if (buffer.length < length) {
- buffer = new Array[Byte](length)
- }
- Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
- out.write(buffer, 0, length)
- }
+ override def close(): Unit = {
+ binaryMap.free()
+ }
- val iter = binaryMap.iterator()
- while (iter.hasNext) {
- val loc = iter.next()
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(loc.getKeyLength)
- out.writeInt(loc.getValueLength)
- write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
- write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ out.writeInt(numFields)
+ // TODO: move these into BytesToBytesMap
+ out.writeInt(binaryMap.numKeys())
+ out.writeInt(binaryMap.numValues())
+
+ var buffer = new Array[Byte](64)
+ def write(base: Object, offset: Long, length: Int): Unit = {
+ if (buffer.length < length) {
+ buffer = new Array[Byte](length)
}
+ Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+ out.write(buffer, 0, length)
+ }
- } else {
- assert(hashTable != null)
- out.writeInt(hashTable.size())
-
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val key = entry.getKey
- val values = entry.getValue
-
- // write all the values as single byte array
- var totalSize = 0L
- var i = 0
- while (i < values.length) {
- totalSize += values(i).getSizeInBytes + 4 + 4
- i += 1
- }
- assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(key.getSizeInBytes)
- out.writeInt(totalSize.toInt)
- out.write(key.getBytes)
- i = 0
- while (i < values.length) {
- // [num of fields] [num of bytes] [row bytes]
- // write the integer in native order, so they can be read by UNSAFE.getInt()
- if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
- out.writeInt(values(i).numFields())
- out.writeInt(values(i).getSizeInBytes)
- } else {
- out.writeInt(Integer.reverseBytes(values(i).numFields()))
- out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
- }
- out.write(values(i).getBytes)
- i += 1
- }
- }
+ val iter = binaryMap.iterator()
+ while (iter.hasNext) {
+ val loc = iter.next()
+ // [key size] [values size] [key bytes] [value bytes]
+ out.writeInt(loc.getKeyLength)
+ out.writeInt(loc.getValueLength)
+ write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+ write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
}
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ numFields = in.readInt()
val nKeys = in.readInt()
+ val nValues = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
// TODO(josh): This needs to be revisited before we merge this patch; making this change now
// so that tests compile:
@@ -314,7 +239,7 @@ private[joins] final class UnsafeHashedRelation(
var i = 0
var keyBuffer = new Array[Byte](1024)
var valuesBuffer = new Array[Byte](1024)
- while (i < nKeys) {
+ while (i < nValues) {
val keySize = in.readInt()
val valuesSize = in.readInt()
if (keySize > keyBuffer.length) {
@@ -326,13 +251,11 @@ private[joins] final class UnsafeHashedRelation(
}
in.readFully(valuesBuffer, 0, valuesSize)
- // put it into binary map
val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
- assert(!loc.isDefined, "Duplicated key found!")
- val putSuceeded = loc.putNewKey(
- keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+ val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
if (!putSuceeded) {
+ binaryMap.free()
throw new IOException("Could not allocate memory to grow BytesToBytesMap")
}
i += 1
@@ -340,6 +263,29 @@ private[joins] final class UnsafeHashedRelation(
}
}
+/**
+ * A HashedRelation for UnsafeRow with unique keys.
+ */
+private[joins] final class UniqueUnsafeHashedRelation(
+ private var numFields: Int,
+ private var binaryMap: BytesToBytesMap)
+ extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
+ def getValue(key: InternalRow): InternalRow = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ val row = new UnsafeRow(numFields)
+ row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ row
+ } else {
+ null
+ }
+ }
+}
+
private[joins] object UnsafeHashedRelation {
def apply(
@@ -347,29 +293,54 @@ private[joins] object UnsafeHashedRelation {
keyGenerator: UnsafeProjection,
sizeEstimate: Int): HashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
- // TODO: Use BytesToBytesMap for memory efficiency
- val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val taskMemoryManager = if (TaskContext.get() != null) {
+ TaskContext.get().taskMemoryManager()
+ } else {
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ }
+ val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+ .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+ val binaryMap = new BytesToBytesMap(
+ taskMemoryManager,
+ // Only 70% of the slots can be used before growing, more capacity help to reduce collision
+ (sizeEstimate * 1.5 + 1).toInt,
+ pageSizeBytes)
// Create a mapping of buildKeys -> rows
+ var numFields = 0
+ // Whether all the keys are unique or not
+ var allUnique: Boolean = true
while (input.hasNext) {
- val unsafeRow = input.next().asInstanceOf[UnsafeRow]
- val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[UnsafeRow]()
- hashTable.put(rowKey.copy(), newMatchList)
- newMatchList
- } else {
- existingMatchList
+ val row = input.next().asInstanceOf[UnsafeRow]
+ numFields = row.numFields()
+ val key = keyGenerator(row)
+ if (!key.anyNull) {
+ val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ if (loc.isDefined) {
+ allUnique = false
+ }
+ val success = loc.append(
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+ if (!success) {
+ binaryMap.free()
+ throw new SparkException("There is no enough memory to build hash map")
}
- matchList += unsafeRow
}
}
- // TODO: create UniqueUnsafeRelation
- new UnsafeHashedRelation(hashTable)
+ if (allUnique) {
+ new UniqueUnsafeHashedRelation(numFields, binaryMap)
+ } else {
+ new UnsafeHashedRelation(numFields, binaryMap)
+ }
}
}
@@ -523,7 +494,7 @@ private[joins] object LongHashedRelation {
keyGenerator: Projection,
sizeEstimate: Int): HashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
+ // TODO: use LongToBytesMap for better memory efficiency
val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
// Create a mapping of key -> rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5c4f1ef60f..e3a2eaea5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -57,9 +57,19 @@ case class ShuffledHashJoin(
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+ val context = TaskContext.get()
+ if (!canJoinKeyFitWithinLong) {
+ // build BytesToBytesMap
+ val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
+ // This relation is usually used until the end of task.
+ context.addTaskCompletionListener((t: TaskContext) =>
+ relation.close()
+ )
+ return relation
+ }
+
// try to acquire some memory for the hash table, it could trigger other operator to free some
// memory. The memory acquired here will mostly be used until the end of task.
- val context = TaskContext.get()
val memoryManager = context.taskMemoryManager()
var acquired = 0L
var used = 0L
@@ -69,18 +79,18 @@ case class ShuffledHashJoin(
val copiedIter = iter.map { row =>
// It's hard to guess what's exactly memory will be used, we have a rough guess here.
- // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
- // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
+ // TODO: use LongToBytesMap instead of HashMap for memory efficiency
+ // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers
val needed = 150 + row.getSizeInBytes
if (needed > acquired - used) {
val got = memoryManager.acquireExecutionMemory(
Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+ acquired += got
if (got < needed) {
throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
"hash join, please use sort merge join by setting " +
"spark.sql.join.preferSortMergeJoin=true")
}
- acquired += got
}
used += needed
// HashedRelation requires that the UnsafeRow should be separate objects.