aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-03-28 13:07:32 -0700
committerDavies Liu <davies.liu@gmail.com>2016-03-28 13:07:32 -0700
commitd7b58f1461f71ee3c028360eef0ffedd17d6a076 (patch)
tree58ddca8bb29534ecb77446e6706f33d885e01bd4 /sql/core/src/main
parent600c0b69cab4767e8e5a6f4284777d8b9d4bd40e (diff)
downloadspark-d7b58f1461f71ee3c028360eef0ffedd17d6a076.tar.gz
spark-d7b58f1461f71ee3c028360eef0ffedd17d6a076.tar.bz2
spark-d7b58f1461f71ee3c028360eef0ffedd17d6a076.zip
[SPARK-14052] [SQL] build a BytesToBytesMap directly in HashedRelation
## What changes were proposed in this pull request? Currently, for the key that can not fit within a long, we build a hash map for UnsafeHashedRelation, it's converted to BytesToBytesMap after serialization and deserialization. We should build a BytesToBytesMap directly to have better memory efficiency. In order to do that, BytesToBytesMap should support multiple (K,V) pair with the same K, Location.putNewKey() is renamed to Location.append(), which could append multiple values for the same key (same Location). `Location.newValue()` is added to find the next value for the same key. ## How was this patch tested? Existing tests. Added benchmark for broadcast hash join with duplicated keys. Author: Davies Liu <davies@databricks.com> Closes #11870 from davies/map2.
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala281
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala18
3 files changed, 141 insertions, 160 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 8882903bbf..1f1b5389aa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -134,7 +134,7 @@ public final class UnsafeFixedWidthAggregationMap {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
- boolean putSucceeded = loc.putNewKey(
+ boolean putSucceeded = loc.append(
key.getBaseObject(),
key.getBaseOffset(),
key.getSizeInBytes(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528639..dc4793e85a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,18 +18,18 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
import org.apache.spark.util.collection.CompactBuffer
/**
@@ -54,6 +54,11 @@ private[execution] sealed trait HashedRelation {
*/
def getMemorySize: Long = 1L // to make the test happy
+ /**
+ * Release any used resources.
+ */
+ def close(): Unit = {}
+
// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -132,163 +137,83 @@ private[execution] object HashedRelation {
}
/**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap.
*
* It's serialized in the following format:
* [number of keys]
- * [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- * ...
- *
- * All the values are serialized as following:
- * [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- * ...
+ * [size of key] [size of value] [key bytes] [bytes for value]
*/
-private[joins] final class UnsafeHashedRelation(
- private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
- extends HashedRelation
- with KnownSizeEstimation
- with Externalizable {
-
- private[joins] def this() = this(null) // Needed for serialization
+private[joins] class UnsafeHashedRelation(
+ private var numFields: Int,
+ private var binaryMap: BytesToBytesMap)
+ extends HashedRelation with KnownSizeEstimation with Externalizable {
- // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
- // This is used in broadcast joins and distributed mode only
- @transient private[this] var binaryMap: BytesToBytesMap = _
+ private[joins] def this() = this(0, null) // Needed for serialization
- /**
- * Return the size of the unsafe map on the executors.
- *
- * For broadcast joins, this hashed relation is bigger on the driver because it is
- * represented as a Java hash map there. While serializing the map to the executors,
- * however, we rehash the contents in a binary map to reduce the memory footprint on
- * the executors.
- *
- * For non-broadcast joins or in local mode, return 0.
- */
override def getMemorySize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- 0
- }
+ binaryMap.getTotalMemoryConsumption
}
override def estimatedSize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- SizeEstimator.estimate(hashTable)
- }
+ binaryMap.getTotalMemoryConsumption
}
override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
-
- if (binaryMap != null) {
- // Used in Broadcast join
- val map = binaryMap // avoid the compiler error
- val loc = new map.Location // this could be allocated in stack
- binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
- if (loc.isDefined) {
- val buffer = CompactBuffer[UnsafeRow]()
-
- val base = loc.getValueBase
- var offset = loc.getValueOffset
- val last = offset + loc.getValueLength
- while (offset < last) {
- val numFields = Platform.getInt(base, offset)
- val sizeInBytes = Platform.getInt(base, offset + 4)
- offset += 8
-
- val row = new UnsafeRow(numFields)
- row.pointTo(base, offset, sizeInBytes)
- buffer += row
- offset += sizeInBytes
- }
- buffer
- } else {
- null
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ val buffer = CompactBuffer[UnsafeRow]()
+ val row = new UnsafeRow(numFields)
+ row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ buffer += row
+ while (loc.nextValue()) {
+ val row = new UnsafeRow(numFields)
+ row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ buffer += row
}
-
+ buffer
} else {
- // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
- hashTable.get(unsafeKey)
+ null
}
}
- override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- if (binaryMap != null) {
- // This could happen when a cached broadcast object need to be dumped into disk to free memory
- out.writeInt(binaryMap.numElements())
-
- var buffer = new Array[Byte](64)
- def write(base: Object, offset: Long, length: Int): Unit = {
- if (buffer.length < length) {
- buffer = new Array[Byte](length)
- }
- Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
- out.write(buffer, 0, length)
- }
+ override def close(): Unit = {
+ binaryMap.free()
+ }
- val iter = binaryMap.iterator()
- while (iter.hasNext) {
- val loc = iter.next()
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(loc.getKeyLength)
- out.writeInt(loc.getValueLength)
- write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
- write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ out.writeInt(numFields)
+ // TODO: move these into BytesToBytesMap
+ out.writeInt(binaryMap.numKeys())
+ out.writeInt(binaryMap.numValues())
+
+ var buffer = new Array[Byte](64)
+ def write(base: Object, offset: Long, length: Int): Unit = {
+ if (buffer.length < length) {
+ buffer = new Array[Byte](length)
}
+ Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+ out.write(buffer, 0, length)
+ }
- } else {
- assert(hashTable != null)
- out.writeInt(hashTable.size())
-
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val key = entry.getKey
- val values = entry.getValue
-
- // write all the values as single byte array
- var totalSize = 0L
- var i = 0
- while (i < values.length) {
- totalSize += values(i).getSizeInBytes + 4 + 4
- i += 1
- }
- assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(key.getSizeInBytes)
- out.writeInt(totalSize.toInt)
- out.write(key.getBytes)
- i = 0
- while (i < values.length) {
- // [num of fields] [num of bytes] [row bytes]
- // write the integer in native order, so they can be read by UNSAFE.getInt()
- if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
- out.writeInt(values(i).numFields())
- out.writeInt(values(i).getSizeInBytes)
- } else {
- out.writeInt(Integer.reverseBytes(values(i).numFields()))
- out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
- }
- out.write(values(i).getBytes)
- i += 1
- }
- }
+ val iter = binaryMap.iterator()
+ while (iter.hasNext) {
+ val loc = iter.next()
+ // [key size] [values size] [key bytes] [value bytes]
+ out.writeInt(loc.getKeyLength)
+ out.writeInt(loc.getValueLength)
+ write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+ write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
}
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ numFields = in.readInt()
val nKeys = in.readInt()
+ val nValues = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
// TODO(josh): This needs to be revisited before we merge this patch; making this change now
// so that tests compile:
@@ -314,7 +239,7 @@ private[joins] final class UnsafeHashedRelation(
var i = 0
var keyBuffer = new Array[Byte](1024)
var valuesBuffer = new Array[Byte](1024)
- while (i < nKeys) {
+ while (i < nValues) {
val keySize = in.readInt()
val valuesSize = in.readInt()
if (keySize > keyBuffer.length) {
@@ -326,13 +251,11 @@ private[joins] final class UnsafeHashedRelation(
}
in.readFully(valuesBuffer, 0, valuesSize)
- // put it into binary map
val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
- assert(!loc.isDefined, "Duplicated key found!")
- val putSuceeded = loc.putNewKey(
- keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+ val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
if (!putSuceeded) {
+ binaryMap.free()
throw new IOException("Could not allocate memory to grow BytesToBytesMap")
}
i += 1
@@ -340,6 +263,29 @@ private[joins] final class UnsafeHashedRelation(
}
}
+/**
+ * A HashedRelation for UnsafeRow with unique keys.
+ */
+private[joins] final class UniqueUnsafeHashedRelation(
+ private var numFields: Int,
+ private var binaryMap: BytesToBytesMap)
+ extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
+ def getValue(key: InternalRow): InternalRow = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ val row = new UnsafeRow(numFields)
+ row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ row
+ } else {
+ null
+ }
+ }
+}
+
private[joins] object UnsafeHashedRelation {
def apply(
@@ -347,29 +293,54 @@ private[joins] object UnsafeHashedRelation {
keyGenerator: UnsafeProjection,
sizeEstimate: Int): HashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
- // TODO: Use BytesToBytesMap for memory efficiency
- val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val taskMemoryManager = if (TaskContext.get() != null) {
+ TaskContext.get().taskMemoryManager()
+ } else {
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ }
+ val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+ .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+ val binaryMap = new BytesToBytesMap(
+ taskMemoryManager,
+ // Only 70% of the slots can be used before growing, more capacity help to reduce collision
+ (sizeEstimate * 1.5 + 1).toInt,
+ pageSizeBytes)
// Create a mapping of buildKeys -> rows
+ var numFields = 0
+ // Whether all the keys are unique or not
+ var allUnique: Boolean = true
while (input.hasNext) {
- val unsafeRow = input.next().asInstanceOf[UnsafeRow]
- val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[UnsafeRow]()
- hashTable.put(rowKey.copy(), newMatchList)
- newMatchList
- } else {
- existingMatchList
+ val row = input.next().asInstanceOf[UnsafeRow]
+ numFields = row.numFields()
+ val key = keyGenerator(row)
+ if (!key.anyNull) {
+ val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ if (loc.isDefined) {
+ allUnique = false
+ }
+ val success = loc.append(
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+ if (!success) {
+ binaryMap.free()
+ throw new SparkException("There is no enough memory to build hash map")
}
- matchList += unsafeRow
}
}
- // TODO: create UniqueUnsafeRelation
- new UnsafeHashedRelation(hashTable)
+ if (allUnique) {
+ new UniqueUnsafeHashedRelation(numFields, binaryMap)
+ } else {
+ new UnsafeHashedRelation(numFields, binaryMap)
+ }
}
}
@@ -523,7 +494,7 @@ private[joins] object LongHashedRelation {
keyGenerator: Projection,
sizeEstimate: Int): HashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
+ // TODO: use LongToBytesMap for better memory efficiency
val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
// Create a mapping of key -> rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5c4f1ef60f..e3a2eaea5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -57,9 +57,19 @@ case class ShuffledHashJoin(
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+ val context = TaskContext.get()
+ if (!canJoinKeyFitWithinLong) {
+ // build BytesToBytesMap
+ val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
+ // This relation is usually used until the end of task.
+ context.addTaskCompletionListener((t: TaskContext) =>
+ relation.close()
+ )
+ return relation
+ }
+
// try to acquire some memory for the hash table, it could trigger other operator to free some
// memory. The memory acquired here will mostly be used until the end of task.
- val context = TaskContext.get()
val memoryManager = context.taskMemoryManager()
var acquired = 0L
var used = 0L
@@ -69,18 +79,18 @@ case class ShuffledHashJoin(
val copiedIter = iter.map { row =>
// It's hard to guess what's exactly memory will be used, we have a rough guess here.
- // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
- // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
+ // TODO: use LongToBytesMap instead of HashMap for memory efficiency
+ // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers
val needed = 150 + row.getSizeInBytes
if (needed > acquired - used) {
val got = memoryManager.acquireExecutionMemory(
Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+ acquired += got
if (got < needed) {
throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
"hash join, please use sort merge join by setting " +
"spark.sql.join.preferSortMergeJoin=true")
}
- acquired += got
}
used += needed
// HashedRelation requires that the UnsafeRow should be separate objects.