path: root/sql/core
author     Davies Liu <davies@databricks.com>    2016-04-09 00:37:55 -0700
committer  Davies Liu <davies.liu@gmail.com>     2016-04-09 00:37:55 -0700
commit     90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f (patch)
tree       bb0c38896d7b02e4dd612a68f22bcc982d383d08 /sql/core
parent     520dde48d0d52dbbbbe1710a3275fdd5355dd69d (diff)
[SPARK-14419] [SQL] Improve HashedRelation for keys that fit within a Long
## What changes were proposed in this pull request?

Currently, we use a java HashMap for HashedRelation if the key could fit within a Long. The java HashMap and CompactBuffer are not memory efficient, and the memory they use is not accounted for accurately. This PR introduces a LongToUnsafeRowMap (similar to BytesToBytesMap) for better memory efficiency and performance.

## How was this patch tested?

Updated existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #12190 from davies/long_map2.
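The patch replaces the old (canJoinKeyFitWithinLong, keyGenerator) entry point with a single HashedRelation.apply that takes the key expressions and, optionally, a TaskMemoryManager. A minimal usage sketch in Scala, modeled on the updated HashedRelationSuite further down; it assumes a Spark test classpath, and the rows and key below are illustrative rather than part of the patch:

import org.apache.spark.SparkConf
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
import org.apache.spark.sql.execution.joins.HashedRelation
import org.apache.spark.sql.types.LongType

// A standalone TaskMemoryManager, mirroring the fallback that HashedRelation.apply
// builds for itself when no manager is passed in.
val mm = new TaskMemoryManager(
  new StaticMemoryManager(
    new SparkConf().set("spark.memory.offHeap.enabled", "false"),
    Long.MaxValue, Long.MaxValue, 1),
  0)

// Illustrative build-side rows with a single LongType column; the join key is column 0.
val toUnsafe = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
val rows = (0L until 100L).map(i => toUnsafe(InternalRow(i)).copy())
val key = Seq(BoundReference(0, LongType, false))

// A single LongType key selects the new LongToUnsafeRowMap-backed LongHashedRelation;
// any other key shape falls back to the BytesToBytesMap-backed UnsafeHashedRelation.
val relation = HashedRelation(rows.iterator, key, sizeEstimate = 64, taskMemoryManager = mm)
assert(relation.keyIsUnique)
assert(relation.getValue(42L) != null)
relation.close()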
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala      3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala         18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala                  31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala           688
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala          51
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala     132
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala                    8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala       48
8 files changed, 633 insertions, 346 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 0a5a72c52a..692fef703f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -454,7 +454,7 @@ case class TungstenAggregate(
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
- ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+ ctx.addMutableState(hashMapClassName, hashMapTerm, s"")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
@@ -467,6 +467,7 @@ case class TungstenAggregate(
s"""
${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
private void $doAgg() throws java.io.IOException {
+ $hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index e3d554c2de..a8f854136c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.LongType
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -50,10 +51,7 @@ case class BroadcastHashJoin(
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
override def requiredChildDistribution: Seq[Distribution] = {
- val mode = HashedRelationBroadcastMode(
- canJoinKeyFitWithinLong,
- rewriteKeyExpr(buildKeys),
- buildPlan.output)
+ val mode = HashedRelationBroadcastMode(buildKeys)
buildSide match {
case BuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -68,7 +66,7 @@ case class BroadcastHashJoin(
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val hashed = broadcastRelation.value.asReadOnlyCopy()
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize)
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
join(streamedIter, hashed, numOutputRows)
}
}
@@ -105,7 +103,7 @@ case class BroadcastHashJoin(
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
- | incPeakExecutionMemory($relationTerm.getMemorySize());
+ | incPeakExecutionMemory($relationTerm.estimatedSize());
""".stripMargin)
(broadcastRelation, relationTerm)
}
@@ -118,15 +116,13 @@ case class BroadcastHashJoin(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
- if (canJoinKeyFitWithinLong) {
+ if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
// generate the join key as Long
- val expr = rewriteKeyExpr(streamedKeys).head
- val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
+ val ev = streamedKeys.head.gen(ctx)
(ev, ev.isNull)
} else {
// generate the join key as UnsafeRow
- val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
- val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
(ev, s"${ev.value}.anyNull()")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8f45d57126..4c912d371e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -59,9 +59,13 @@ trait HashJoin {
case BuildRight => (right, left)
}
- protected lazy val (buildKeys, streamedKeys) = buildSide match {
- case BuildLeft => (leftKeys, rightKeys)
- case BuildRight => (rightKeys, leftKeys)
+ protected lazy val (buildKeys, streamedKeys) = {
+ val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
+ val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
+ buildSide match {
+ case BuildLeft => (lkeys, rkeys)
+ case BuildRight => (rkeys, lkeys)
+ }
}
/**
@@ -84,17 +88,8 @@ trait HashJoin {
width = dt.defaultSize
} else {
val bits = dt.defaultSize * 8
- // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same
- // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys
- // with two same ints have hash code 0, we rotate the bits of second one.
- val rotated = if (e.dataType == IntegerType) {
- // (e >>> 15) | (e << 17)
- BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
- } else {
- e
- }
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
- BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
+ BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
width -= bits
}
// TODO: support BooleanType, DateType and TimestampType
@@ -105,17 +100,11 @@ trait HashJoin {
keyExpr :: Nil
}
- protected lazy val canJoinKeyFitWithinLong: Boolean = {
- val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
- val key = rewriteKeyExpr(buildKeys)
- sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
- }
-
protected def buildSideKeyGenerator(): Projection =
- UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
+ UnsafeProjection.create(buildKeys)
protected def streamSideKeyGenerator(): UnsafeProjection =
- UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
+ UnsafeProjection.create(streamedKeys)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
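A note on the HashJoin.scala change above: rewriteKeyExpr packs several narrow, fixed-width key columns into a single Long by shifting the accumulated key left by the next column's bit width and OR-ing in that column's low bits; the int-rotation workaround for Long.hashCode is dropped here, since the new long-keyed map probes on `key % cap` rather than on that hash. A standalone sketch of the same packing arithmetic in plain Scala, for two Int keys (the helper names are illustrative, not part of the patch):

// Pack two Int key columns into one Long the way rewriteKeyExpr composes
// ShiftLeft / BitwiseAnd / BitwiseOr over the key expressions: the earlier key
// occupies the high bits, the later key the low bits.
def packTwoIntKeys(a: Int, b: Int): Long =
  (a.toLong << 32) | (b.toLong & 0xffffffffL)

// The high 32 bits recover the first key and the low 32 bits the second,
// so distinct (a, b) pairs always map to distinct packed Longs.
def unpackTwoIntKeys(key: Long): (Int, Int) =
  ((key >>> 32).toInt, key.toInt)

assert(packTwoIntKeys(7, -1) == 0x00000007ffffffffL)
assert(unpackTwoIntKeys(packTwoIntKeys(7, -1)) == ((7, -1)))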
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 5ccb435686..4959f60dab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,24 +18,22 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
-import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException}
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.LongType
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.{KnownSizeEstimation, Utils}
-import org.apache.spark.util.collection.CompactBuffer
/**
* Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
* object.
*/
-private[execution] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
/**
* Returns matched rows.
*
@@ -75,50 +73,35 @@ private[execution] sealed trait HashedRelation {
def asReadOnlyCopy(): HashedRelation
/**
- * Returns the size of used memory.
- */
- def getMemorySize: Long = 1L // to make the test happy
-
- /**
* Release any used resources.
*/
- def close(): Unit = {}
-
- // This is a helper method to implement Externalizable, and is used by
- // GeneralHashedRelation and UniqueKeyHashedRelation
- protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
- out.writeInt(serialized.length) // Write the length of serialized bytes first
- out.write(serialized)
- }
-
- // This is a helper method to implement Externalizable, and is used by
- // GeneralHashedRelation and UniqueKeyHashedRelation
- protected def readBytes(in: ObjectInput): Array[Byte] = {
- val serializedSize = in.readInt() // Read the length of serialized bytes first
- val bytes = new Array[Byte](serializedSize)
- in.readFully(bytes)
- bytes
- }
+ def close(): Unit
}
private[execution] object HashedRelation {
/**
* Create a HashedRelation from an Iterator of InternalRow.
- *
- * Note: The caller should make sure that these InternalRow are different objects.
*/
def apply(
- canJoinKeyFitWithinLong: Boolean,
input: Iterator[InternalRow],
- keyGenerator: Projection,
- sizeEstimate: Int = 64): HashedRelation = {
+ key: Seq[Expression],
+ sizeEstimate: Int = 64,
+ taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
+ val mm = Option(taskMemoryManager).getOrElse {
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ }
- if (canJoinKeyFitWithinLong) {
- LongHashedRelation(input, keyGenerator, sizeEstimate)
+ if (key.length == 1 && key.head.dataType == LongType) {
+ LongHashedRelation(input, key, sizeEstimate, mm)
} else {
- UnsafeHashedRelation(
- input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+ UnsafeHashedRelation(input, key, sizeEstimate, mm)
}
}
}
@@ -133,7 +116,7 @@ private[execution] object HashedRelation {
private[joins] class UnsafeHashedRelation(
private var numFields: Int,
private var binaryMap: BytesToBytesMap)
- extends HashedRelation with KnownSizeEstimation with Externalizable {
+ extends HashedRelation with Externalizable {
private[joins] def this() = this(0, null) // Needed for serialization
@@ -142,10 +125,6 @@ private[joins] class UnsafeHashedRelation(
override def asReadOnlyCopy(): UnsafeHashedRelation =
new UnsafeHashedRelation(numFields, binaryMap)
- override def getMemorySize: Long = {
- binaryMap.getTotalMemoryConsumption
- }
-
override def estimatedSize: Long = {
binaryMap.getTotalMemoryConsumption
}
@@ -276,20 +255,10 @@ private[joins] object UnsafeHashedRelation {
def apply(
input: Iterator[InternalRow],
- keyGenerator: UnsafeProjection,
- sizeEstimate: Int): HashedRelation = {
+ key: Seq[Expression],
+ sizeEstimate: Int,
+ taskMemoryManager: TaskMemoryManager): HashedRelation = {
- val taskMemoryManager = if (TaskContext.get() != null) {
- TaskContext.get().taskMemoryManager()
- } else {
- new TaskMemoryManager(
- new StaticMemoryManager(
- new SparkConf().set("spark.memory.offHeap.enabled", "false"),
- Long.MaxValue,
- Long.MaxValue,
- 1),
- 0)
- }
val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
.getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
@@ -300,6 +269,7 @@ private[joins] object UnsafeHashedRelation {
pageSizeBytes)
// Create a mapping of buildKeys -> rows
+ val keyGenerator = UnsafeProjection.create(key)
var numFields = 0
while (input.hasNext) {
val row = input.next().asInstanceOf[UnsafeRow]
@@ -321,144 +291,471 @@ private[joins] object UnsafeHashedRelation {
}
}
+private[joins] object LongToUnsafeRowMap {
+ // the largest prime below 2^n
+ val LARGEST_PRIMES = {
+ // https://primes.utm.edu/lists/2small/0bit.html
+ val diffs = Seq(
+ 0, 1, 1, 3, 1, 3, 1, 5,
+ 3, 3, 9, 3, 1, 3, 19, 15,
+ 1, 5, 1, 3, 9, 3, 15, 3,
+ 39, 5, 39, 57, 3, 35, 1, 5
+ )
+ val primes = new Array[Int](32)
+ primes(0) = 1
+ var power2 = 1
+ (1 until 32).foreach { i =>
+ power2 *= 2
+ primes(i) = power2 - diffs(i)
+ }
+ primes
+ }
+}
+
/**
- * An interface for a hashed relation that the key is a Long.
+ * An append-only hash map mapping from key of Long to UnsafeRow.
+ *
+ * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array
+ * (`page`) in this format:
+ *
+ * [bytes of row1][address1][bytes of row2][address2] ...
+ *
+ * address1 (8 bytes) is the offset and size of the next value for the same key as row1; a key
+ * could have multiple values. The address at the end of the last value for every key is 0.
+ *
+ * The keys and addresses of their values could be stored in two modes:
+ *
+ * 1) sparse mode: the keys and addresses are stored in `array` as:
+ *
+ * [key1][address1][key2][address2]...[]
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1
+ * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to resolve
+ * hash collisions.
+ *
+ * 2) dense mode: all the addresses are packed into a single array of long, as:
+ *
+ * [address1] [address2] ...
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1; its position is
+ * determined by `key1 - minKey`.
+ *
+ * The map is created in sparse mode, and key-value pairs can then be appended to it. Once finished
+ * appending, the caller can call optimize() to try to turn the map into dense mode, which is faster
+ * to probe.
*/
-private[joins] trait LongHashedRelation extends HashedRelation {
- override def get(key: InternalRow): Iterator[InternalRow] = {
- get(key.getLong(0))
+private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int)
+ extends MemoryConsumer(mm) with Externalizable {
+ import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._
+
+ // Whether the keys are stored in dense mode or not.
+ private var isDense = false
+
+ // The minimum value of keys.
+ private var minKey = Long.MaxValue
+
+ // The maximum value of keys.
+ private var maxKey = Long.MinValue
+
+ // Sparse mode: the actual capacity of the map; it is a prime number.
+ private var cap: Int = 0
+
+ // The array to store the key and offset of UnsafeRow in the page.
+ //
+ // Sparse mode: [key1] [offset1 | size1] [key2] [offset2 | size2] ...
+ // Dense mode: [offset1 | size1] [offset2 | size2]
+ private var array: Array[Long] = null
+
+ // The page to store all bytes of the UnsafeRows and the pointers to the next rows.
+ // [row1][pointer1] [row2][pointer2]
+ private var page: Array[Byte] = null
+
+ // Current write cursor in the page.
+ private var cursor = Platform.BYTE_ARRAY_OFFSET
+
+ // The total number of values of all keys.
+ private var numValues = 0
+
+ // The number of unique keys.
+ private var numKeys = 0
+
+ // needed by serializer
+ def this() = {
+ this(
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0),
+ 0)
}
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
+
+ private def acquireMemory(size: Long): Unit = {
+ // do not support spilling
+ val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this)
+ if (got < size) {
+ mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this)
+ throw new SparkException(s"Can't acquire $size bytes memory to build hash relation")
+ }
}
-}
-private[joins] final class GeneralLongHashedRelation(
- private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
- extends LongHashedRelation with Externalizable {
+ private def freeMemory(size: Long): Unit = {
+ mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this)
+ }
- // Needed for serialization (it is public to make Java serialization work)
- def this() = this(null)
+ private def init(): Unit = {
+ if (mm != null) {
+ cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{
+ sys.error(s"Can't create map with capacity $capacity")
+ }
+ acquireMemory(cap * 2 * 8 + (1 << 20))
+ array = new Array[Long](cap * 2)
+ page = new Array[Byte](1 << 20) // 1M bytes
+ }
+ }
+
+ init()
+
+ def spill(size: Long, trigger: MemoryConsumer): Long = {
+ 0L
+ }
+
+ /**
+ * Returns whether all the keys are unique.
+ */
+ def keyIsUnique: Boolean = numKeys == numValues
+
+ /**
+ * Returns total memory consumption.
+ */
+ def getTotalMemoryConsumption: Long = {
+ array.length * 8 + page.length
+ }
- override def keyIsUnique: Boolean = false
+ /**
+ * Returns the slot in the array that stores the key (sparse mode).
+ */
+ private def getSlot(key: Long): Int = {
+ var s = (key % cap).toInt
+ if (s < 0) {
+ s += cap
+ }
+ s * 2
+ }
- override def asReadOnlyCopy(): GeneralLongHashedRelation =
- new GeneralLongHashedRelation(hashTable)
+ private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
+ val offset = address >>> 32
+ val size = address & 0xffffffffL
+ resultRow.pointTo(page, offset, size.toInt)
+ resultRow
+ }
- override def get(key: Long): Iterator[InternalRow] = {
- val rows = hashTable.get(key)
- if (rows != null) {
- rows.toIterator
+ /**
+ * Returns the single UnsafeRow for given key, or null if not found.
+ */
+ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
+ if (isDense) {
+ val idx = (key - minKey).toInt
+ if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+ return getRow(array(idx), resultRow)
+ }
} else {
- null
+ var pos = getSlot(key)
+ var step = 1
+ while (array(pos + 1) != 0) {
+ if (array(pos) == key) {
+ return getRow(array(pos + 1), resultRow)
+ }
+ pos += 2 * step
+ step += 1
+ if (pos >= array.length) {
+ pos -= array.length
+ }
+ }
+ }
+ null
+ }
+
+ /**
+ * Returns an iterator of UnsafeRow over multiple linked values.
+ */
+ private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+ new Iterator[UnsafeRow] {
+ var addr = address
+ override def hasNext: Boolean = addr != 0
+ override def next(): UnsafeRow = {
+ val offset = addr >>> 32
+ val size = addr & 0xffffffffL
+ resultRow.pointTo(page, offset, size.toInt)
+ addr = Platform.getLong(page, offset + size)
+ resultRow
+ }
+ }
+ }
+
+ /**
+ * Returns an iterator for all the values for the given key, or null if no value found.
+ */
+ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+ if (isDense) {
+ val idx = (key - minKey).toInt
+ if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+ return valueIter(array(idx), resultRow)
+ }
+ } else {
+ var pos = getSlot(key)
+ var step = 1
+ while (array(pos + 1) != 0) {
+ if (array(pos) == key) {
+ return valueIter(array(pos + 1), resultRow)
+ }
+ pos += 2 * step
+ step += 1
+ if (pos >= array.length) {
+ pos -= array.length
+ }
+ }
+ }
+ null
+ }
+
+ /**
+ * Appends the key and row into this map.
+ */
+ def append(key: Long, row: UnsafeRow): Unit = {
+ if (key < minKey) {
+ minKey = key
+ }
+ if (key > maxKey) {
+ maxKey = key
+ }
+
+ // There are 8 bytes for the pointer to the next value
+ if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+ val used = page.length
+ if (used * 2L > (1L << 31)) {
+ sys.error("Can't allocate a page that is larger than 2G")
+ }
+ acquireMemory(used * 2)
+ val newPage = new Array[Byte](used * 2)
+ System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+ page = newPage
+ freeMemory(used)
+ }
+
+ // copy the bytes of UnsafeRow
+ val offset = cursor
+ Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
+ cursor += row.getSizeInBytes
+ Platform.putLong(page, cursor, 0)
+ cursor += 8
+ numValues += 1
+ updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+ }
+
+ /**
+ * Updates the address in the array for the given key.
+ */
+ private def updateIndex(key: Long, address: Long): Unit = {
+ var pos = getSlot(key)
+ var step = 1
+ while (array(pos + 1) != 0 && array(pos) != key) {
+ pos += 2 * step
+ step += 1
+ if (pos >= array.length) {
+ pos -= array.length
+ }
+ }
+ if (array(pos + 1) == 0) {
+ // this is the first value for this key; put the address in the array.
+ array(pos) = key
+ array(pos + 1) = address
+ numKeys += 1
+ if (numKeys * 2 > cap) {
+ // reached half of the capacity
+ growArray()
+ }
+ } else {
+ // there is another value for this key; put the address at the end of the last value.
+ var addr = array(pos + 1)
+ var pointer = (addr >>> 32) + (addr & 0xffffffffL)
+ while (Platform.getLong(page, pointer) != 0) {
+ addr = Platform.getLong(page, pointer)
+ pointer = (addr >>> 32) + (addr & 0xffffffffL)
+ }
+ Platform.putLong(page, pointer, address)
+ }
+ }
+
+ private def growArray(): Unit = {
+ val old_cap = cap
+ var old_array = array
+ cap = LARGEST_PRIMES.find(_ > cap).getOrElse{
+ sys.error(s"Can't grow map any more than $cap")
+ }
+ numKeys = 0
+ acquireMemory(cap * 2 * 8)
+ array = new Array[Long](cap * 2)
+ var i = 0
+ while (i < old_array.length) {
+ if (old_array(i + 1) > 0) {
+ updateIndex(old_array(i), old_array(i + 1))
+ }
+ i += 2
+ }
+ old_array = null // release the reference to old array
+ freeMemory(old_cap * 2 * 8)
+ }
+
+ /**
+ * Try to turn the map into dense mode, which is faster to probe.
+ */
+ def optimize(): Unit = {
+ val range = maxKey - minKey
+ // Convert to dense mode if it does not require more memory or could fit within L1 cache
+ if (range < array.length || range < 1024) {
+ try {
+ acquireMemory((range + 1) * 8)
+ } catch {
+ case e: SparkException =>
+ // there is not enough memory to convert
+ return
+ }
+ val denseArray = new Array[Long]((range + 1).toInt)
+ var i = 0
+ while (i < array.length) {
+ if (array(i + 1) > 0) {
+ val idx = (array(i) - minKey).toInt
+ denseArray(idx) = array(i + 1)
+ }
+ i += 2
+ }
+ val old_length = array.length
+ array = denseArray
+ isDense = true
+ freeMemory(old_length * 8)
+ }
+ }
+
+ /**
+ * Free all the memory acquired by this map.
+ */
+ def free(): Unit = {
+ if (page != null) {
+ freeMemory(page.length)
+ page = null
+ }
+ if (array != null) {
+ freeMemory(array.length * 8)
+ array = null
}
}
override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ out.writeBoolean(isDense)
+ out.writeLong(minKey)
+ out.writeLong(maxKey)
+ out.writeInt(numKeys)
+ out.writeInt(numValues)
+ out.writeInt(cap)
+
+ out.writeInt(array.length)
+ val buffer = new Array[Byte](4 << 10)
+ var offset = Platform.LONG_ARRAY_OFFSET
+ val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+ while (offset < end) {
+ val size = Math.min(buffer.length, end - offset)
+ Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+ out.write(buffer, 0, size)
+ offset += size
+ }
+
+ val used = cursor - Platform.BYTE_ARRAY_OFFSET
+ out.writeInt(used)
+ out.write(page, 0, used)
}
override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ isDense = in.readBoolean()
+ minKey = in.readLong()
+ maxKey = in.readLong()
+ numKeys = in.readInt()
+ numValues = in.readInt()
+ cap = in.readInt()
+
+ val length = in.readInt()
+ array = new Array[Long](length)
+ val buffer = new Array[Byte](4 << 10)
+ var offset = Platform.LONG_ARRAY_OFFSET
+ val end = length * 8 + Platform.LONG_ARRAY_OFFSET
+ while (offset < end) {
+ val size = Math.min(buffer.length, end - offset)
+ in.readFully(buffer, 0, size)
+ Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
+ offset += size
+ }
+
+ val numBytes = in.readInt()
+ page = new Array[Byte](numBytes)
+ in.readFully(page)
}
}
-/**
- * A relation that pack all the rows into a byte array, together with offsets and sizes.
- *
- * All the bytes of UnsafeRow are packed together as `bytes`:
- *
- * [ Row0 ][ Row1 ][] ... [ RowN ]
- *
- * With keys:
- *
- * start start+1 ... start+N
- *
- * `offsets` are offsets of UnsafeRows in the `bytes`
- * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
- *
- * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as:
- *
- * start = 3
- * offsets = [0, 0, 24]
- * sizes = [24, 0, 32]
- * bytes = [0 - 24][][24 - 56]
- */
-private[joins] final class LongArrayRelation(
- private var numFields: Int,
- private var start: Long,
- private var offsets: Array[Int],
- private var sizes: Array[Int],
- private var bytes: Array[Byte]
- ) extends LongHashedRelation with Externalizable {
+private[joins] class LongHashedRelation(
+ private var nFields: Int,
+ private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable {
+
+ private var resultRow: UnsafeRow = new UnsafeRow(nFields)
// Needed for serialization (it is public to make Java serialization work)
- def this() = this(0, 0L, null, null, null)
+ def this() = this(0, null)
- override def keyIsUnique: Boolean = true
+ override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map)
- override def asReadOnlyCopy(): LongArrayRelation = {
- new LongArrayRelation(numFields, start, offsets, sizes, bytes)
+ override def estimatedSize: Long = {
+ map.getTotalMemoryConsumption
}
- override def getMemorySize: Long = {
- offsets.length * 4 + sizes.length * 4 + bytes.length
+ override def get(key: InternalRow): Iterator[InternalRow] = {
+ if (key.isNullAt(0)) {
+ null
+ } else {
+ get(key.getLong(0))
+ }
}
- override def get(key: Long): Iterator[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- Seq(row).toIterator
- } else {
+ override def getValue(key: InternalRow): InternalRow = {
+ if (key.isNullAt(0)) {
null
+ } else {
+ getValue(key.getLong(0))
}
}
- var resultRow = new UnsafeRow(numFields)
+ override def get(key: Long): Iterator[InternalRow] =
+ map.get(key, resultRow)
+
override def getValue(key: Long): InternalRow = {
- val idx = (key - start).toInt
- if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
- resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
- resultRow
- } else {
- null
- }
+ map.getValue(key, resultRow)
+ }
+
+ override def keyIsUnique: Boolean = map.keyIsUnique
+
+ override def close(): Unit = {
+ map.free()
}
override def writeExternal(out: ObjectOutput): Unit = {
- out.writeInt(numFields)
- out.writeLong(start)
- out.writeInt(sizes.length)
- var i = 0
- while (i < sizes.length) {
- out.writeInt(sizes(i))
- i += 1
- }
- out.writeInt(bytes.length)
- out.write(bytes)
+ out.writeInt(nFields)
+ out.writeObject(map)
}
override def readExternal(in: ObjectInput): Unit = {
- numFields = in.readInt()
- resultRow = new UnsafeRow(numFields)
- start = in.readLong()
- val length = in.readInt()
- // read sizes of rows
- sizes = new Array[Int](length)
- offsets = new Array[Int](length)
- var i = 0
- var offset = 0
- while (i < length) {
- offsets(i) = offset
- sizes(i) = in.readInt()
- offset += sizes(i)
- i += 1
- }
- // read all the bytes
- val total = in.readInt()
- assert(total == offset)
- bytes = new Array[Byte](total)
- in.readFully(bytes)
+ nFields = in.readInt()
+ resultRow = new UnsafeRow(nFields)
+ map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
}
}
@@ -466,96 +763,45 @@ private[joins] final class LongArrayRelation(
* Create hashed relation with key that is long.
*/
private[joins] object LongHashedRelation {
-
- val DENSE_FACTOR = 0.2
-
def apply(
- input: Iterator[InternalRow],
- keyGenerator: Projection,
- sizeEstimate: Int): HashedRelation = {
+ input: Iterator[InternalRow],
+ key: Seq[Expression],
+ sizeEstimate: Int,
+ taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
- // TODO: use LongToBytesMap for better memory efficiency
- val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
+ val keyGenerator = UnsafeProjection.create(key)
// Create a mapping of key -> rows
var numFields = 0
- var keyIsUnique = true
- var minKey = Long.MaxValue
- var maxKey = Long.MinValue
while (input.hasNext) {
val unsafeRow = input.next().asInstanceOf[UnsafeRow]
numFields = unsafeRow.numFields()
val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.anyNull) {
+ if (!rowKey.isNullAt(0)) {
val key = rowKey.getLong(0)
- minKey = math.min(minKey, key)
- maxKey = math.max(maxKey, key)
- val existingMatchList = hashTable.get(key)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[UnsafeRow]()
- hashTable.put(key, newMatchList)
- newMatchList
- } else {
- keyIsUnique = false
- existingMatchList
- }
- matchList += unsafeRow
- }
- }
-
- if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
- // The keys are dense enough, so use LongArrayRelation
- val length = (maxKey - minKey).toInt + 1
- val sizes = new Array[Int](length)
- val offsets = new Array[Int](length)
- var offset = 0
- var i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- offsets(i) = offset
- sizes(i) = rows(0).getSizeInBytes
- offset += sizes(i)
- }
- i += 1
- }
- val bytes = new Array[Byte](offset)
- i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
- }
- i += 1
+ map.append(key, unsafeRow)
}
- new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
- } else {
- new GeneralLongHashedRelation(hashTable)
}
+ map.optimize()
+ new LongHashedRelation(numFields, map)
}
}
/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
-private[execution] case class HashedRelationBroadcastMode(
- canJoinKeyFitWithinLong: Boolean,
- keys: Seq[Expression],
- attributes: Seq[Attribute]) extends BroadcastMode {
+private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
+ extends BroadcastMode {
override def transform(rows: Array[InternalRow]): HashedRelation = {
- val generator = UnsafeProjection.create(keys, attributes)
- HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
+ HashedRelation(rows.iterator, canonicalizedKey, rows.length)
}
- private lazy val canonicalizedKeys: Seq[Expression] = {
- keys.map { e =>
- BindReferences.bindReference(e.canonicalized, attributes)
- }
+ private lazy val canonicalizedKey: Seq[Expression] = {
+ key.map { e => e.canonicalized }
}
override def compatibleWith(other: BroadcastMode): Boolean = other match {
- case m: HashedRelationBroadcastMode =>
- canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong &&
- canonicalizedKeys == m.canonicalizedKeys
+ case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey
case _ => false
}
}
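Two mechanics from the LongToUnsafeRowMap doc comment above, sketched standalone in plain Scala: a value's address packs its page offset into the high 32 bits of a Long and its size into the low 32 bits, and sparse-mode lookups probe the (key, address) array with growing triangular steps. The helpers below are illustrative only; the real class additionally manages the page bytes, the chaining of multiple values per key, and memory accounting:

// Address encoding: high 32 bits = offset of the row bytes in the page,
// low 32 bits = size of the row bytes.
def encodeAddress(offset: Long, size: Int): Long = (offset << 32) | (size & 0xffffffffL)
def addressOffset(address: Long): Long = address >>> 32
def addressSize(address: Long): Int = (address & 0xffffffffL).toInt

assert(addressOffset(encodeAddress(1024L, 48)) == 1024L)
assert(addressSize(encodeAddress(1024L, 48)) == 48)

// Sparse-mode probe sequence: start at slot (key % cap) * 2 in the array of
// (key, address) pairs and, on each collision, advance by 2, 4, 6, ... slots
// (quadratic probing with triangular numbers), wrapping around the array.
def probePositions(key: Long, cap: Int, attempts: Int): Seq[Int] = {
  val slots = cap * 2
  var pos = { val s = (key % cap).toInt; (if (s < 0) s + cap else s) * 2 }
  var step = 1
  (0 until attempts).map { _ =>
    val current = pos
    pos += 2 * step
    step += 1
    if (pos >= slots) pos -= slots
    current
  }
}

// With a capacity of 7, key 3 probes slots 6, 8, 12, 4, ...
assert(probePositions(3L, 7, 4) == Seq(6, 8, 12, 4))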
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index bf86096379..0c3e3c3fc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.memory.MemoryMode
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -57,54 +56,20 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
- private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+ private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
val context = TaskContext.get()
- if (!canJoinKeyFitWithinLong) {
- // build BytesToBytesMap
- val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
- // This relation is usually used until the end of task.
- context.addTaskCompletionListener((t: TaskContext) =>
- relation.close()
- )
- return relation
- }
-
- // try to acquire some memory for the hash table, it could trigger other operator to free some
- // memory. The memory acquired here will mostly be used until the end of task.
- val memoryManager = context.taskMemoryManager()
- var acquired = 0L
- var used = 0L
+ val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+ // This relation is usually used until the end of task.
context.addTaskCompletionListener((t: TaskContext) =>
- memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
+ relation.close()
)
-
- val copiedIter = iter.map { row =>
- // It's hard to guess what's exactly memory will be used, we have a rough guess here.
- // TODO: use LongToBytesMap instead of HashMap for memory efficiency
- // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers
- val needed = 150 + row.getSizeInBytes
- if (needed > acquired - used) {
- val got = memoryManager.acquireExecutionMemory(
- Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
- acquired += got
- if (got < needed) {
- throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
- "hash join, please use sort merge join by setting " +
- "spark.sql.join.preferSortMergeJoin=true")
- }
- }
- used += needed
- // HashedRelation requires that the UnsafeRow should be separate objects.
- row.copy()
- }
-
- HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
+ relation
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
- val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
+ val hashed = buildHashedRelation(buildIter)
join(streamIter, hashed, numOutputRows)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 5dbf619876..352fd07d0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -21,6 +21,7 @@ import java.util.HashMap
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.vectorized.AggregateHashMap
@@ -179,8 +180,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X
- Join w long codegen=true 275 / 352 76.2 13.1 19.4X
+ Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X
+ Join w long codegen=true 321 / 371 65.3 15.3 9.3X
*/
runBenchmark("Join w long duplicated", N) {
@@ -193,8 +194,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X
- Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X
+ Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X
+ Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X
*/
val dim2 = broadcast(sqlContext.range(M)
@@ -211,8 +212,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X
- Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X
+ Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X
+ Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X
*/
val dim3 = broadcast(sqlContext.range(M)
@@ -259,8 +260,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X
- outer join w long codegen=true 216 / 226 97.2 10.3 26.3X
+ outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X
+ outer join w long codegen=true 261 / 276 80.5 12.4 11.7X
*/
runBenchmark("semi join w long", N) {
@@ -272,8 +273,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X
- semi join w long codegen=true 211 / 229 99.2 10.1 22.2X
+ semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X
+ semi join w long codegen=true 237 / 244 88.3 11.3 8.1X
*/
}
@@ -326,8 +327,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X
- shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X
+ shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X
+ shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X
*/
}
@@ -349,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("hash and BytesToBytesMap") {
- val N = 10 << 20
+ val N = 20 << 20
val benchmark = new Benchmark("BytesToBytesMap", N)
- benchmark.addCase("hash") { iter =>
+ benchmark.addCase("UnsafeRowhash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
@@ -368,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
+ benchmark.addCase("murmur3 hash") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var p = 524283
+ var s = 0
+ while (i < N) {
+ var h = Murmur3_x86_32.hashLong(i, 42)
+ key.setInt(0, h)
+ s += h
+ i += 1
+ }
+ }
+
benchmark.addCase("fast hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var p = 524283
var s = 0
while (i < N) {
- key.setInt(0, i % 1000)
- val h = Murmur3_x86_32.hashLong(i % 1000, 42)
+ var h = i % p
+ if (h < 0) {
+ h += p
+ }
+ key.setInt(0, h)
s += h
i += 1
}
@@ -475,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
+ Seq(false, true).foreach { optimized =>
+ benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter =>
+ var i = 0
+ val valueBytes = new Array[Byte](16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val map = new LongToUnsafeRowMap(taskMemoryManager, 64)
+ while (i < 65536) {
+ value.setInt(0, i)
+ val key = i % 100000
+ map.append(key, value)
+ i += 1
+ }
+ if (optimized) {
+ map.optimize()
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ val key = i % 100000
+ if (map.getValue(key, value) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+ }
+
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
@@ -493,18 +549,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var i = 0
- while (i < N) {
+ val numKeys = 65536
+ while (i < numKeys) {
key.setInt(0, i % 65536)
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
Murmur3_x86_32.hashLong(i % 65536, 42))
- if (loc.isDefined) {
- value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- value.setInt(0, value.getInt(0) + 1)
- i += 1
- } else {
+ if (!loc.isDefined) {
loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
}
+ i += 1
+ }
+ i = 0
+ var s = 0
+ while (i < N) {
+ key.setInt(0, i % 100000)
+ val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ Murmur3_x86_32.hashLong(i % 100000, 42))
+ if (loc.isDefined) {
+ s += 1
+ }
+ i += 1
}
}
}
@@ -535,16 +600,19 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- hash 112 / 116 93.2 10.7 1.0X
- fast hash 65 / 69 160.9 6.2 1.7X
- arrayEqual 66 / 69 159.1 6.3 1.7X
- Java HashMap (Long) 137 / 182 76.3 13.1 0.8X
- Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X
- Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X
- BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X
- BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X
- Aggregate HashMap 56 / 62 187.9 5.3 2.0X
- */
+ UnsafeRow hash 267 / 284 78.4 12.8 1.0X
+ murmur3 hash 102 / 129 205.5 4.9 2.6X
+ fast hash 79 / 96 263.8 3.8 3.4X
+ arrayEqual 164 / 172 128.2 7.8 1.6X
+ Java HashMap (Long) 321 / 399 65.4 15.3 0.8X
+ Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X
+ Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X
+ LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X
+ LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X
+ BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X
+ BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X
+ Aggregate HashMap 121 / 131 173.3 5.8 2.2X
+ */
benchmark.run()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 9680f3a008..17f2343cf9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
test("compatible BroadcastMode") {
val mode1 = IdentityBroadcastMode
- val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
- val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
+ val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
+ val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)
assert(mode1.compatibleWith(mode1))
assert(!mode1.compatibleWith(mode2))
@@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
assert(plan sameResult plan)
val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan)
- val hashMode = HashedRelationBroadcastMode(true, output, plan.output)
+ val hashMode = HashedRelationBroadcastMode(output)
val exchange2 = BroadcastExchange(hashMode, plan)
val hashMode2 =
- HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output)
+ HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
val exchange3 = BroadcastExchange(hashMode2, plan)
val exchange4 = ReusedExchange(output, exchange3)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index ed87a99439..371a9ed617 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -30,15 +30,23 @@ import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
+ val mm = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy())
+
val buildKey = Seq(BoundReference(0, IntegerType, false))
- val keyGenerator = UnsafeProjection.create(buildKey)
- val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
+ val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
@@ -100,31 +108,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
}
- test("LongArrayRelation") {
+ test("LongToUnsafeRowMap") {
val unsafeProj = UnsafeProjection.create(
Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
- val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
- val longRelation = LongHashedRelation(rows.iterator, keyProj, 100)
- assert(longRelation.isInstanceOf[LongArrayRelation])
- val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
+ val key = Seq(BoundReference(0, IntegerType, false))
+ val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
+ assert(longRelation.keyIsUnique)
(0 until 100).foreach { i =>
- val row = longArrayRelation.getValue(i)
+ val row = longRelation.getValue(i)
assert(row.getInt(0) === i)
assert(row.getInt(1) === i + 1)
}
+ val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
+ assert(!longRelation2.keyIsUnique)
+ (0 until 100).foreach { i =>
+ val rows = longRelation2.get(i).toArray
+ assert(rows.length === 2)
+ assert(rows(0).getInt(0) === i)
+ assert(rows(0).getInt(1) === i + 1)
+ assert(rows(1).getInt(0) === i)
+ assert(rows(1).getInt(1) === i + 1)
+ }
+
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
- longArrayRelation.writeExternal(out)
+ longRelation2.writeExternal(out)
out.flush()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
- val relation = new LongArrayRelation()
+ val relation = new LongHashedRelation()
relation.readExternal(in)
+ assert(!relation.keyIsUnique)
(0 until 100).foreach { i =>
- val row = longArrayRelation.getValue(i)
- assert(row.getInt(0) === i)
- assert(row.getInt(1) === i + 1)
+ val rows = relation.get(i).toArray
+ assert(rows.length === 2)
+ assert(rows(0).getInt(0) === i)
+ assert(rows(0).getInt(1) === i + 1)
+ assert(rows(1).getInt(0) === i)
+ assert(rows(1).getInt(1) === i + 1)
}
}
}