author     Davies Liu <davies.liu@gmail.com>  2016-04-09 13:51:28 -0700
committer  Davies Liu <davies.liu@gmail.com>  2016-04-09 13:51:28 -0700
commit     f7ec854f1b7f575c4c7437daf8e6992c684b6de2 (patch)
tree       be06f3cc7a743683150b4eab9bb552d5f0350da8 /sql/core
parent     adb9d73cd6543c9edfc6b03a6d20061ff09c69f9 (diff)
Revert "[SPARK-14419] [SQL] Improve HashedRelation for key fit within Long"
This reverts commit 90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala  |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala       |  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala                |  31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala          | 688
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala        |  51
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala    | 132
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala                 |   8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala     |  48
8 files changed, 346 insertions(+), 633 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 692fef703f..0a5a72c52a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -454,7 +454,7 @@ case class TungstenAggregate(
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
- ctx.addMutableState(hashMapClassName, hashMapTerm, s"")
+ ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
@@ -467,7 +467,6 @@ case class TungstenAggregate(
s"""
${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
private void $doAgg() throws java.io.IOException {
- $hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index a8f854136c..e3d554c2de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.LongType
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -51,7 +50,10 @@ case class BroadcastHashJoin(
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
override def requiredChildDistribution: Seq[Distribution] = {
- val mode = HashedRelationBroadcastMode(buildKeys)
+ val mode = HashedRelationBroadcastMode(
+ canJoinKeyFitWithinLong,
+ rewriteKeyExpr(buildKeys),
+ buildPlan.output)
buildSide match {
case BuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -66,7 +68,7 @@ case class BroadcastHashJoin(
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val hashed = broadcastRelation.value.asReadOnlyCopy()
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize)
join(streamedIter, hashed, numOutputRows)
}
}
@@ -103,7 +105,7 @@ case class BroadcastHashJoin(
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
- | incPeakExecutionMemory($relationTerm.estimatedSize());
+ | incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)
(broadcastRelation, relationTerm)
}
@@ -116,13 +118,15 @@ case class BroadcastHashJoin(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
- if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+ if (canJoinKeyFitWithinLong) {
// generate the join key as Long
- val ev = streamedKeys.head.gen(ctx)
+ val expr = rewriteKeyExpr(streamedKeys).head
+ val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
(ev, ev.isNull)
} else {
// generate the join key as UnsafeRow
- val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+ val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+ val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
(ev, s"${ev.value}.anyNull()")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 4c912d371e..8f45d57126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -59,13 +59,9 @@ trait HashJoin {
case BuildRight => (right, left)
}
- protected lazy val (buildKeys, streamedKeys) = {
- val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
- val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
- buildSide match {
- case BuildLeft => (lkeys, rkeys)
- case BuildRight => (rkeys, lkeys)
- }
+ protected lazy val (buildKeys, streamedKeys) = buildSide match {
+ case BuildLeft => (leftKeys, rightKeys)
+ case BuildRight => (rightKeys, leftKeys)
}
/**
@@ -88,8 +84,17 @@ trait HashJoin {
width = dt.defaultSize
} else {
val bits = dt.defaultSize * 8
+ // hashCode of Long is (l >> 32) ^ l.toInt, so a long whose high 32 bits equal its low
+ // 32 bits has hash code 0. To avoid the worst case, where a key built from two identical
+ // ints always hashes to 0, we rotate the bits of the second one.
+ val rotated = if (e.dataType == IntegerType) {
+ // (e >>> 15) | (e << 17)
+ BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
+ } else {
+ e
+ }
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
- BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+ BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
width -= bits
}
// TODO: support BooleanType, DateType and TimestampType
@@ -100,11 +105,17 @@ trait HashJoin {
keyExpr :: Nil
}
+ protected lazy val canJoinKeyFitWithinLong: Boolean = {
+ val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
+ val key = rewriteKeyExpr(buildKeys)
+ sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
+ }
+
protected def buildSideKeyGenerator(): Projection =
- UnsafeProjection.create(buildKeys)
+ UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
protected def streamSideKeyGenerator(): UnsafeProjection =
- UnsafeProjection.create(streamedKeys)
+ UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
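
For reference, a standalone Scala sketch (not part of the patch; PackTwoIntKeysSketch, rotate and pack are illustrative names) of the key packing that rewriteKeyExpr performs when the join keys fit within a long: the first int goes into the high 32 bits and the rotated second int into the low 32 bits, so that a pair of identical ints no longer yields a long whose hash code is 0.

object PackTwoIntKeysSketch {
  // Mirrors (e >>> 15) | (e << 17) from rewriteKeyExpr.
  private def rotate(v: Int): Int = (v >>> 15) | (v << 17)

  // Pack two int keys into one long: first key in the high 32 bits,
  // rotated second key in the low 32 bits.
  def pack(k1: Int, k2: Int): Long =
    (k1.toLong << 32) | (rotate(k2).toLong & 0xffffffffL)

  def main(args: Array[String]): Unit = {
    val unrotated = (7L << 32) | 7L             // high == low 32 bits, Long hash code is 0
    val rotated = pack(7, 7)                    // halves differ, hash code is non-zero
    println(java.lang.Long.hashCode(unrotated)) // prints 0
    println(java.lang.Long.hashCode(rotated))   // prints a non-zero value
  }
}
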
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 4959f60dab..5ccb435686 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,22 +18,24 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.{SparkConf, SparkEnv, SparkException}
-import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.types.LongType
+import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.{KnownSizeEstimation, Utils}
+import org.apache.spark.util.collection.CompactBuffer
/**
* Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
* object.
*/
-private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
+private[execution] sealed trait HashedRelation {
/**
* Returns matched rows.
*
@@ -73,35 +75,50 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
def asReadOnlyCopy(): HashedRelation
/**
+ * Returns the size of used memory.
+ */
+ def getMemorySize: Long = 1L // to make the test happy
+
+ /**
* Release any used resources.
*/
- def close(): Unit
+ def close(): Unit = {}
+
+ // This is a helper method to implement Externalizable, and is used by
+ // GeneralHashedRelation and UniqueKeyHashedRelation
+ protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
+ out.writeInt(serialized.length) // Write the length of serialized bytes first
+ out.write(serialized)
+ }
+
+ // This is a helper method to implement Externalizable, and is used by
+ // GeneralHashedRelation and UniqueKeyHashedRelation
+ protected def readBytes(in: ObjectInput): Array[Byte] = {
+ val serializedSize = in.readInt() // Read the length of serialized bytes first
+ val bytes = new Array[Byte](serializedSize)
+ in.readFully(bytes)
+ bytes
+ }
}
private[execution] object HashedRelation {
/**
* Create a HashedRelation from an Iterator of InternalRow.
+ *
+ * Note: The caller should make sure that these InternalRows are different objects.
*/
def apply(
+ canJoinKeyFitWithinLong: Boolean,
input: Iterator[InternalRow],
- key: Seq[Expression],
- sizeEstimate: Int = 64,
- taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
- val mm = Option(taskMemoryManager).getOrElse {
- new TaskMemoryManager(
- new StaticMemoryManager(
- new SparkConf().set("spark.memory.offHeap.enabled", "false"),
- Long.MaxValue,
- Long.MaxValue,
- 1),
- 0)
- }
+ keyGenerator: Projection,
+ sizeEstimate: Int = 64): HashedRelation = {
- if (key.length == 1 && key.head.dataType == LongType) {
- LongHashedRelation(input, key, sizeEstimate, mm)
+ if (canJoinKeyFitWithinLong) {
+ LongHashedRelation(input, keyGenerator, sizeEstimate)
} else {
- UnsafeHashedRelation(input, key, sizeEstimate, mm)
+ UnsafeHashedRelation(
+ input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
}
}
}
@@ -116,7 +133,7 @@ private[execution] object HashedRelation {
private[joins] class UnsafeHashedRelation(
private var numFields: Int,
private var binaryMap: BytesToBytesMap)
- extends HashedRelation with Externalizable {
+ extends HashedRelation with KnownSizeEstimation with Externalizable {
private[joins] def this() = this(0, null) // Needed for serialization
@@ -125,6 +142,10 @@ private[joins] class UnsafeHashedRelation(
override def asReadOnlyCopy(): UnsafeHashedRelation =
new UnsafeHashedRelation(numFields, binaryMap)
+ override def getMemorySize: Long = {
+ binaryMap.getTotalMemoryConsumption
+ }
+
override def estimatedSize: Long = {
binaryMap.getTotalMemoryConsumption
}
@@ -255,10 +276,20 @@ private[joins] object UnsafeHashedRelation {
def apply(
input: Iterator[InternalRow],
- key: Seq[Expression],
- sizeEstimate: Int,
- taskMemoryManager: TaskMemoryManager): HashedRelation = {
+ keyGenerator: UnsafeProjection,
+ sizeEstimate: Int): HashedRelation = {
+ val taskMemoryManager = if (TaskContext.get() != null) {
+ TaskContext.get().taskMemoryManager()
+ } else {
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ }
val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
.getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
@@ -269,7 +300,6 @@ private[joins] object UnsafeHashedRelation {
pageSizeBytes)
// Create a mapping of buildKeys -> rows
- val keyGenerator = UnsafeProjection.create(key)
var numFields = 0
while (input.hasNext) {
val row = input.next().asInstanceOf[UnsafeRow]
@@ -291,471 +321,144 @@ private[joins] object UnsafeHashedRelation {
}
}
-private[joins] object LongToUnsafeRowMap {
- // the largest prime that below 2^n
- val LARGEST_PRIMES = {
- // https://primes.utm.edu/lists/2small/0bit.html
- val diffs = Seq(
- 0, 1, 1, 3, 1, 3, 1, 5,
- 3, 3, 9, 3, 1, 3, 19, 15,
- 1, 5, 1, 3, 9, 3, 15, 3,
- 39, 5, 39, 57, 3, 35, 1, 5
- )
- val primes = new Array[Int](32)
- primes(0) = 1
- var power2 = 1
- (1 until 32).foreach { i =>
- power2 *= 2
- primes(i) = power2 - diffs(i)
- }
- primes
- }
-}
-
/**
- * An append-only hash map mapping from key of Long to UnsafeRow.
- *
- * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array
- * (`page`) in this format:
- *
- * [bytes of row1][address1][bytes of row2][address1] ...
- *
- * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key
- * could have multiple values. the address at the end of last value for every key is 0.
- *
- * The keys and addresses of their values could be stored in two modes:
- *
- * 1) sparse mode: the keys and addresses are stored in `array` as:
- *
- * [key1][address1][key2][address2]...[]
- *
- * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1
- * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address
- * hash collision.
- *
- * 2) dense mode: all the addresses are packed into a single array of long, as:
- *
- * [address1] [address2] ...
- *
- * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is
- * determined by `key1 - minKey`.
- *
- * The map is created as sparse mode, then key-value could be appended into it. Once finish
- * appending, caller could all optimize() to try to turn the map into dense mode, which is faster
- * to probe.
+ * An interface for a hashed relation whose key is a Long.
*/
-private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int)
- extends MemoryConsumer(mm) with Externalizable {
- import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._
-
- // Whether the keys are stored in dense mode or not.
- private var isDense = false
-
- // The minimum value of keys.
- private var minKey = Long.MaxValue
-
- // The Maxinum value of keys.
- private var maxKey = Long.MinValue
-
- // Sparse mode: the actual capacity of map, is a prime number.
- private var cap: Int = 0
-
- // The array to store the key and offset of UnsafeRow in the page.
- //
- // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ...
- // Dense mode: [offset1 | size1] [offset2 | size2]
- private var array: Array[Long] = null
-
- // The page to store all bytes of UnsafeRow and the pointer to next rows.
- // [row1][pointer1] [row2][pointer2]
- private var page: Array[Byte] = null
-
- // Current write cursor in the page.
- private var cursor = Platform.BYTE_ARRAY_OFFSET
-
- // The total number of values of all keys.
- private var numValues = 0
-
- // The number of unique keys.
- private var numKeys = 0
-
- // needed by serializer
- def this() = {
- this(
- new TaskMemoryManager(
- new StaticMemoryManager(
- new SparkConf().set("spark.memory.offHeap.enabled", "false"),
- Long.MaxValue,
- Long.MaxValue,
- 1),
- 0),
- 0)
- }
-
- private def acquireMemory(size: Long): Unit = {
- // do not support spilling
- val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this)
- if (got < size) {
- mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this)
- throw new SparkException(s"Can't acquire $size bytes memory to build hash relation")
- }
- }
-
- private def freeMemory(size: Long): Unit = {
- mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this)
- }
-
- private def init(): Unit = {
- if (mm != null) {
- cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{
- sys.error(s"Can't create map with capacity $capacity")
- }
- acquireMemory(cap * 2 * 8 + (1 << 20))
- array = new Array[Long](cap * 2)
- page = new Array[Byte](1 << 20) // 1M bytes
- }
- }
-
- init()
-
- def spill(size: Long, trigger: MemoryConsumer): Long = {
- 0L
- }
-
- /**
- * Returns whether all the keys are unique.
- */
- def keyIsUnique: Boolean = numKeys == numValues
-
- /**
- * Returns total memory consumption.
- */
- def getTotalMemoryConsumption: Long = {
- array.length * 8 + page.length
- }
-
- /**
- * Returns the slot of array that store the keys (sparse mode).
- */
- private def getSlot(key: Long): Int = {
- var s = (key % cap).toInt
- if (s < 0) {
- s += cap
- }
- s * 2
- }
-
- private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
- val offset = address >>> 32
- val size = address & 0xffffffffL
- resultRow.pointTo(page, offset, size.toInt)
- resultRow
+private[joins] trait LongHashedRelation extends HashedRelation {
+ override def get(key: InternalRow): Iterator[InternalRow] = {
+ get(key.getLong(0))
}
-
- /**
- * Returns the single UnsafeRow for given key, or null if not found.
- */
- def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
- if (isDense) {
- val idx = (key - minKey).toInt
- if (idx >= 0 && key <= maxKey && array(idx) > 0) {
- return getRow(array(idx), resultRow)
- }
- } else {
- var pos = getSlot(key)
- var step = 1
- while (array(pos + 1) != 0) {
- if (array(pos) == key) {
- return getRow(array(pos + 1), resultRow)
- }
- pos += 2 * step
- step += 1
- if (pos >= array.length) {
- pos -= array.length
- }
- }
- }
- null
+ override def getValue(key: InternalRow): InternalRow = {
+ getValue(key.getLong(0))
}
+}
- /**
- * Returns an interator of UnsafeRow for multiple linked values.
- */
- private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
- new Iterator[UnsafeRow] {
- var addr = address
- override def hasNext: Boolean = addr != 0
- override def next(): UnsafeRow = {
- val offset = addr >>> 32
- val size = addr & 0xffffffffL
- resultRow.pointTo(page, offset, size.toInt)
- addr = Platform.getLong(page, offset + size)
- resultRow
- }
- }
- }
+private[joins] final class GeneralLongHashedRelation(
+ private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
+ extends LongHashedRelation with Externalizable {
- /**
- * Returns an iterator for all the values for the given key, or null if no value found.
- */
- def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
- if (isDense) {
- val idx = (key - minKey).toInt
- if (idx >=0 && key <= maxKey && array(idx) > 0) {
- return valueIter(array(idx), resultRow)
- }
- } else {
- var pos = getSlot(key)
- var step = 1
- while (array(pos + 1) != 0) {
- if (array(pos) == key) {
- return valueIter(array(pos + 1), resultRow)
- }
- pos += 2 * step
- step += 1
- if (pos >= array.length) {
- pos -= array.length
- }
- }
- }
- null
- }
-
- /**
- * Appends the key and row into this map.
- */
- def append(key: Long, row: UnsafeRow): Unit = {
- if (key < minKey) {
- minKey = key
- }
- if (key > maxKey) {
- maxKey = key
- }
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(null)
- // There is 8 bytes for the pointer to next value
- if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
- val used = page.length
- if (used * 2L > (1L << 31)) {
- sys.error("Can't allocate a page that is larger than 2G")
- }
- acquireMemory(used * 2)
- val newPage = new Array[Byte](used * 2)
- System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
- page = newPage
- freeMemory(used)
- }
+ override def keyIsUnique: Boolean = false
- // copy the bytes of UnsafeRow
- val offset = cursor
- Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
- cursor += row.getSizeInBytes
- Platform.putLong(page, cursor, 0)
- cursor += 8
- numValues += 1
- updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
- }
+ override def asReadOnlyCopy(): GeneralLongHashedRelation =
+ new GeneralLongHashedRelation(hashTable)
- /**
- * Update the address in array for given key.
- */
- private def updateIndex(key: Long, address: Long): Unit = {
- var pos = getSlot(key)
- var step = 1
- while (array(pos + 1) != 0 && array(pos) != key) {
- pos += 2 * step
- step += 1
- if (pos >= array.length) {
- pos -= array.length
- }
- }
- if (array(pos + 1) == 0) {
- // this is the first value for this key, put the address in array.
- array(pos) = key
- array(pos + 1) = address
- numKeys += 1
- if (numKeys * 2 > cap) {
- // reach half of the capacity
- growArray()
- }
+ override def get(key: Long): Iterator[InternalRow] = {
+ val rows = hashTable.get(key)
+ if (rows != null) {
+ rows.toIterator
} else {
- // there is another value for this key, put the address at the end of final value.
- var addr = array(pos + 1)
- var pointer = (addr >>> 32) + (addr & 0xffffffffL)
- while (Platform.getLong(page, pointer) != 0) {
- addr = Platform.getLong(page, pointer)
- pointer = (addr >>> 32) + (addr & 0xffffffffL)
- }
- Platform.putLong(page, pointer, address)
- }
- }
-
- private def growArray(): Unit = {
- val old_cap = cap
- var old_array = array
- cap = LARGEST_PRIMES.find(_ > cap).getOrElse{
- sys.error(s"Can't grow map any more than $cap")
- }
- numKeys = 0
- acquireMemory(cap * 2 * 8)
- array = new Array[Long](cap * 2)
- var i = 0
- while (i < old_array.length) {
- if (old_array(i + 1) > 0) {
- updateIndex(old_array(i), old_array(i + 1))
- }
- i += 2
- }
- old_array = null // release the reference to old array
- freeMemory(old_cap * 2 * 8)
- }
-
- /**
- * Try to turn the map into dense mode, which is faster to probe.
- */
- def optimize(): Unit = {
- val range = maxKey - minKey
- // Convert to dense mode if it does not require more memory or could fit within L1 cache
- if (range < array.length || range < 1024) {
- try {
- acquireMemory((range + 1) * 8)
- } catch {
- case e: SparkException =>
- // there is no enough memory to convert
- return
- }
- val denseArray = new Array[Long]((range + 1).toInt)
- var i = 0
- while (i < array.length) {
- if (array(i + 1) > 0) {
- val idx = (array(i) - minKey).toInt
- denseArray(idx) = array(i + 1)
- }
- i += 2
- }
- val old_length = array.length
- array = denseArray
- isDense = true
- freeMemory(old_length * 8)
- }
- }
-
- /**
- * Free all the memory acquired by this map.
- */
- def free(): Unit = {
- if (page != null) {
- freeMemory(page.length)
- page = null
- }
- if (array != null) {
- freeMemory(array.length * 8)
- array = null
+ null
}
}
override def writeExternal(out: ObjectOutput): Unit = {
- out.writeBoolean(isDense)
- out.writeLong(minKey)
- out.writeLong(maxKey)
- out.writeInt(numKeys)
- out.writeInt(numValues)
- out.writeInt(cap)
-
- out.writeInt(array.length)
- val buffer = new Array[Byte](4 << 10)
- var offset = Platform.LONG_ARRAY_OFFSET
- val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
- while (offset < end) {
- val size = Math.min(buffer.length, end - offset)
- Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
- out.write(buffer, 0, size)
- offset += size
- }
-
- val used = cursor - Platform.BYTE_ARRAY_OFFSET
- out.writeInt(used)
- out.write(page, 0, used)
+ writeBytes(out, SparkSqlSerializer.serialize(hashTable))
}
override def readExternal(in: ObjectInput): Unit = {
- isDense = in.readBoolean()
- minKey = in.readLong()
- maxKey = in.readLong()
- numKeys = in.readInt()
- numValues = in.readInt()
- cap = in.readInt()
-
- val length = in.readInt()
- array = new Array[Long](length)
- val buffer = new Array[Byte](4 << 10)
- var offset = Platform.LONG_ARRAY_OFFSET
- val end = length * 8 + Platform.LONG_ARRAY_OFFSET
- while (offset < end) {
- val size = Math.min(buffer.length, end - offset)
- in.readFully(buffer, 0, size)
- Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
- offset += size
- }
-
- val numBytes = in.readInt()
- page = new Array[Byte](numBytes)
- in.readFully(page)
+ hashTable = SparkSqlSerializer.deserialize(readBytes(in))
}
}
-private[joins] class LongHashedRelation(
- private var nFields: Int,
- private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable {
-
- private var resultRow: UnsafeRow = new UnsafeRow(nFields)
+/**
+ * A relation that packs all the rows into a byte array, together with offsets and sizes.
+ *
+ * All the bytes of UnsafeRow are packed together as `bytes`:
+ *
+ * [ Row0 ][ Row1 ][] ... [ RowN ]
+ *
+ * With keys:
+ *
+ * start start+1 ... start+N
+ *
+ * `offsets` are offsets of UnsafeRows in the `bytes`
+ * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
+ *
+ * For example, two UnsafeRows (24 bytes and 32 bytes), with keys 3 and 5, will be stored as:
+ *
+ * start = 3
+ * offsets = [0, 0, 24]
+ * sizes = [24, 0, 32]
+ * bytes = [0 - 24][][24 - 56]
+ */
+private[joins] final class LongArrayRelation(
+ private var numFields: Int,
+ private var start: Long,
+ private var offsets: Array[Int],
+ private var sizes: Array[Int],
+ private var bytes: Array[Byte]
+ ) extends LongHashedRelation with Externalizable {
// Needed for serialization (it is public to make Java serialization work)
- def this() = this(0, null)
+ def this() = this(0, 0L, null, null, null)
- override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map)
+ override def keyIsUnique: Boolean = true
- override def estimatedSize: Long = {
- map.getTotalMemoryConsumption
+ override def asReadOnlyCopy(): LongArrayRelation = {
+ new LongArrayRelation(numFields, start, offsets, sizes, bytes)
}
- override def get(key: InternalRow): Iterator[InternalRow] = {
- if (key.isNullAt(0)) {
- null
- } else {
- get(key.getLong(0))
- }
+ override def getMemorySize: Long = {
+ offsets.length * 4 + sizes.length * 4 + bytes.length
}
- override def getValue(key: InternalRow): InternalRow = {
- if (key.isNullAt(0)) {
- null
+ override def get(key: Long): Iterator[InternalRow] = {
+ val row = getValue(key)
+ if (row != null) {
+ Seq(row).toIterator
} else {
- getValue(key.getLong(0))
+ null
}
}
- override def get(key: Long): Iterator[InternalRow] =
- map.get(key, resultRow)
-
+ var resultRow = new UnsafeRow(numFields)
override def getValue(key: Long): InternalRow = {
- map.getValue(key, resultRow)
- }
-
- override def keyIsUnique: Boolean = map.keyIsUnique
-
- override def close(): Unit = {
- map.free()
+ val idx = (key - start).toInt
+ if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
+ resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+ resultRow
+ } else {
+ null
+ }
}
override def writeExternal(out: ObjectOutput): Unit = {
- out.writeInt(nFields)
- out.writeObject(map)
+ out.writeInt(numFields)
+ out.writeLong(start)
+ out.writeInt(sizes.length)
+ var i = 0
+ while (i < sizes.length) {
+ out.writeInt(sizes(i))
+ i += 1
+ }
+ out.writeInt(bytes.length)
+ out.write(bytes)
}
override def readExternal(in: ObjectInput): Unit = {
- nFields = in.readInt()
- resultRow = new UnsafeRow(nFields)
- map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
+ numFields = in.readInt()
+ resultRow = new UnsafeRow(numFields)
+ start = in.readLong()
+ val length = in.readInt()
+ // read sizes of rows
+ sizes = new Array[Int](length)
+ offsets = new Array[Int](length)
+ var i = 0
+ var offset = 0
+ while (i < length) {
+ offsets(i) = offset
+ sizes(i) = in.readInt()
+ offset += sizes(i)
+ i += 1
+ }
+ // read all the bytes
+ val total = in.readInt()
+ assert(total == offset)
+ bytes = new Array[Byte](total)
+ in.readFully(bytes)
}
}
@@ -763,45 +466,96 @@ private[joins] class LongHashedRelation(
* Create hashed relation with key that is long.
*/
private[joins] object LongHashedRelation {
+
+ val DENSE_FACTOR = 0.2
+
def apply(
- input: Iterator[InternalRow],
- key: Seq[Expression],
- sizeEstimate: Int,
- taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
+ input: Iterator[InternalRow],
+ keyGenerator: Projection,
+ sizeEstimate: Int): HashedRelation = {
- val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
- val keyGenerator = UnsafeProjection.create(key)
+ // TODO: use LongToBytesMap for better memory efficiency
+ val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
// Create a mapping of key -> rows
var numFields = 0
+ var keyIsUnique = true
+ var minKey = Long.MaxValue
+ var maxKey = Long.MinValue
while (input.hasNext) {
val unsafeRow = input.next().asInstanceOf[UnsafeRow]
numFields = unsafeRow.numFields()
val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.isNullAt(0)) {
+ if (!rowKey.anyNull) {
val key = rowKey.getLong(0)
- map.append(key, unsafeRow)
+ minKey = math.min(minKey, key)
+ maxKey = math.max(maxKey, key)
+ val existingMatchList = hashTable.get(key)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[UnsafeRow]()
+ hashTable.put(key, newMatchList)
+ newMatchList
+ } else {
+ keyIsUnique = false
+ existingMatchList
+ }
+ matchList += unsafeRow
+ }
+ }
+
+ if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+ // The keys are dense enough, so use LongArrayRelation
+ val length = (maxKey - minKey).toInt + 1
+ val sizes = new Array[Int](length)
+ val offsets = new Array[Int](length)
+ var offset = 0
+ var i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ offsets(i) = offset
+ sizes(i) = rows(0).getSizeInBytes
+ offset += sizes(i)
+ }
+ i += 1
+ }
+ val bytes = new Array[Byte](offset)
+ i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
+ }
+ i += 1
}
+ new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
+ } else {
+ new GeneralLongHashedRelation(hashTable)
}
- map.optimize()
- new LongHashedRelation(numFields, map)
}
}
/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
-private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
- extends BroadcastMode {
+private[execution] case class HashedRelationBroadcastMode(
+ canJoinKeyFitWithinLong: Boolean,
+ keys: Seq[Expression],
+ attributes: Seq[Attribute]) extends BroadcastMode {
override def transform(rows: Array[InternalRow]): HashedRelation = {
- HashedRelation(rows.iterator, canonicalizedKey, rows.length)
+ val generator = UnsafeProjection.create(keys, attributes)
+ HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
}
- private lazy val canonicalizedKey: Seq[Expression] = {
- key.map { e => e.canonicalized }
+ private lazy val canonicalizedKeys: Seq[Expression] = {
+ keys.map { e =>
+ BindReferences.bindReference(e.canonicalized, attributes)
+ }
}
override def compatibleWith(other: BroadcastMode): Boolean = other match {
- case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey
+ case m: HashedRelationBroadcastMode =>
+ canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong &&
+ canonicalizedKeys == m.canonicalizedKeys
case _ => false
}
}
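
A minimal sketch (not part of the patch; DenseLongLayoutSketch, DenseLayout and lookup are hypothetical names) of the dense layout described in the restored LongArrayRelation scaladoc above: once the keys are unique and dense enough (the patch checks size > (maxKey - minKey) * DENSE_FACTOR), a lookup is a plain array index on key - start rather than a hash probe.

object DenseLongLayoutSketch {
  // start: smallest key; sizes(i) == 0 means there is no row for key start + i.
  final case class DenseLayout(start: Long, offsets: Array[Int], sizes: Array[Int], bytes: Array[Byte]) {
    // Returns the (offset, size) slice of `bytes` holding the row for `key`, or None.
    def lookup(key: Long): Option[(Int, Int)] = {
      val idx = (key - start).toInt
      if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) Some((offsets(idx), sizes(idx)))
      else None
    }
  }

  def main(args: Array[String]): Unit = {
    // The scaladoc example: rows of 24 and 32 bytes stored for keys 3 and 5.
    val layout = DenseLayout(start = 3L,
      offsets = Array(0, 0, 24), sizes = Array(24, 0, 32), bytes = new Array[Byte](56))
    assert(layout.lookup(3L) == Some((0, 24)))
    assert(layout.lookup(4L) == None)   // size 0 means no row for this key
    assert(layout.lookup(5L) == Some((24, 32)))
  }
}
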
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 0c3e3c3fc1..bf86096379 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.memory.MemoryMode
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -56,20 +57,54 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
- private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
+ private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
val context = TaskContext.get()
- val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
- // This relation is usually used until the end of task.
+ if (!canJoinKeyFitWithinLong) {
+ // build BytesToBytesMap
+ val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
+ // This relation is usually used until the end of task.
+ context.addTaskCompletionListener((t: TaskContext) =>
+ relation.close()
+ )
+ return relation
+ }
+
+ // Try to acquire some memory for the hash table; this could trigger other operators to free
+ // some memory. The memory acquired here will mostly be used until the end of the task.
+ val memoryManager = context.taskMemoryManager()
+ var acquired = 0L
+ var used = 0L
context.addTaskCompletionListener((t: TaskContext) =>
- relation.close()
+ memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
)
- relation
+
+ val copiedIter = iter.map { row =>
+ // It's hard to guess exactly how much memory will be used, so we make a rough guess here.
+ // TODO: use LongToBytesMap instead of HashMap for memory efficiency
+ // Each pair in the HashMap holds an UnsafeRow, a CompactBuffer, and maybe 10+ pointers
+ val needed = 150 + row.getSizeInBytes
+ if (needed > acquired - used) {
+ val got = memoryManager.acquireExecutionMemory(
+ Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+ acquired += got
+ if (got < needed) {
+ throw new SparkException("Can't acquire enough memory to build hash map in shuffled " +
+ "hash join, please use sort merge join by setting " +
+ "spark.sql.join.preferSortMergeJoin=true")
+ }
+ }
+ used += needed
+ // HashedRelation requires that the UnsafeRows be separate objects.
+ row.copy()
+ }
+
+ HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
- val hashed = buildHashedRelation(buildIter)
+ val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
join(streamIter, hashed, numOutputRows)
}
}
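
A rough sketch (not part of the patch; BuildSideMemorySketch and reserve are hypothetical, and acquire stands in for TaskMemoryManager.acquireExecutionMemory) of the bookkeeping buildHashedRelation now does on the build side: reserve execution memory in at-least-page-sized chunks, charging an assumed ~150 bytes of HashMap/CompactBuffer overhead per row on top of the row bytes.

object BuildSideMemorySketch {
  // Walks the build-side row sizes and reserves memory before the rows are copied
  // into the hash map; throws when a reservation cannot be satisfied.
  def reserve(rowSizes: Iterator[Int], pageSize: Long)(acquire: Long => Long): Long = {
    var acquired = 0L
    var used = 0L
    rowSizes.foreach { size =>
      val needed = 150L + size              // rough per-row overhead + row bytes
      if (needed > acquired - used) {
        val got = acquire(math.max(pageSize, needed))
        acquired += got
        if (got < needed) {
          throw new RuntimeException(
            "Can't acquire enough memory to build the hash map; " +
              "consider sort merge join (spark.sql.join.preferSortMergeJoin=true)")
        }
      }
      used += needed
    }
    acquired
  }

  def main(args: Array[String]): Unit = {
    // Grant whatever is asked for, just to show the accounting.
    val total = reserve(Iterator(64, 64, 64), pageSize = 1L << 20)(requested => requested)
    println(s"reserved $total bytes for 3 rows")
  }
}
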
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 352fd07d0e..5dbf619876 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -21,7 +21,6 @@ import java.util.HashMap
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
-import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.vectorized.AggregateHashMap
@@ -180,8 +179,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X
- Join w long codegen=true 321 / 371 65.3 15.3 9.3X
+ Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X
+ Join w long codegen=true 275 / 352 76.2 13.1 19.4X
*/
runBenchmark("Join w long duplicated", N) {
@@ -194,8 +193,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X
- Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X
+ Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X
+ Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X
*/
val dim2 = broadcast(sqlContext.range(M)
@@ -212,8 +211,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X
- Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X
+ Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X
+ Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X
*/
val dim3 = broadcast(sqlContext.range(M)
@@ -260,8 +259,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X
- outer join w long codegen=true 261 / 276 80.5 12.4 11.7X
+ outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X
+ outer join w long codegen=true 216 / 226 97.2 10.3 26.3X
*/
runBenchmark("semi join w long", N) {
@@ -273,8 +272,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X
- semi join w long codegen=true 237 / 244 88.3 11.3 8.1X
+ semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X
+ semi join w long codegen=true 211 / 229 99.2 10.1 22.2X
*/
}
@@ -327,8 +326,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X
- shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X
+ shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X
+ shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X
*/
}
@@ -350,11 +349,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("hash and BytesToBytesMap") {
- val N = 20 << 20
+ val N = 10 << 20
val benchmark = new Benchmark("BytesToBytesMap", N)
- benchmark.addCase("UnsafeRowhash") { iter =>
+ benchmark.addCase("hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
@@ -369,34 +368,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
- benchmark.addCase("murmur3 hash") { iter =>
- var i = 0
- val keyBytes = new Array[Byte](16)
- val key = new UnsafeRow(1)
- key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
- var p = 524283
- var s = 0
- while (i < N) {
- var h = Murmur3_x86_32.hashLong(i, 42)
- key.setInt(0, h)
- s += h
- i += 1
- }
- }
-
benchmark.addCase("fast hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
- var p = 524283
var s = 0
while (i < N) {
- var h = i % p
- if (h < 0) {
- h += p
- }
- key.setInt(0, h)
+ key.setInt(0, i % 1000)
+ val h = Murmur3_x86_32.hashLong(i % 1000, 42)
s += h
i += 1
}
@@ -495,42 +475,6 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
- Seq(false, true).foreach { optimized =>
- benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter =>
- var i = 0
- val valueBytes = new Array[Byte](16)
- val value = new UnsafeRow(1)
- value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
- value.setInt(0, 555)
- val taskMemoryManager = new TaskMemoryManager(
- new StaticMemoryManager(
- new SparkConf().set("spark.memory.offHeap.enabled", "false"),
- Long.MaxValue,
- Long.MaxValue,
- 1),
- 0)
- val map = new LongToUnsafeRowMap(taskMemoryManager, 64)
- while (i < 65536) {
- value.setInt(0, i)
- val key = i % 100000
- map.append(key, value)
- i += 1
- }
- if (optimized) {
- map.optimize()
- }
- var s = 0
- i = 0
- while (i < N) {
- val key = i % 100000
- if (map.getValue(key, value) != null) {
- s += 1
- }
- i += 1
- }
- }
- }
-
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
@@ -549,27 +493,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var i = 0
- val numKeys = 65536
- while (i < numKeys) {
+ while (i < N) {
key.setInt(0, i % 65536)
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
Murmur3_x86_32.hashLong(i % 65536, 42))
- if (!loc.isDefined) {
+ if (loc.isDefined) {
+ value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ value.setInt(0, value.getInt(0) + 1)
+ i += 1
+ } else {
loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
}
- i += 1
- }
- i = 0
- var s = 0
- while (i < N) {
- key.setInt(0, i % 100000)
- val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
- Murmur3_x86_32.hashLong(i % 100000, 42))
- if (loc.isDefined) {
- s += 1
- }
- i += 1
}
}
}
@@ -600,19 +535,16 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- UnsafeRow hash 267 / 284 78.4 12.8 1.0X
- murmur3 hash 102 / 129 205.5 4.9 2.6X
- fast hash 79 / 96 263.8 3.8 3.4X
- arrayEqual 164 / 172 128.2 7.8 1.6X
- Java HashMap (Long) 321 / 399 65.4 15.3 0.8X
- Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X
- Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X
- LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X
- LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X
- BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X
- BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X
- Aggregate HashMap 121 / 131 173.3 5.8 2.2X
- */
+ hash 112 / 116 93.2 10.7 1.0X
+ fast hash 65 / 69 160.9 6.2 1.7X
+ arrayEqual 66 / 69 159.1 6.3 1.7X
+ Java HashMap (Long) 137 / 182 76.3 13.1 0.8X
+ Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X
+ Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X
+ BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X
+ BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X
+ Aggregate HashMap 56 / 62 187.9 5.3 2.0X
+ */
benchmark.run()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 17f2343cf9..9680f3a008 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
test("compatible BroadcastMode") {
val mode1 = IdentityBroadcastMode
- val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
- val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)
+ val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
+ val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
assert(mode1.compatibleWith(mode1))
assert(!mode1.compatibleWith(mode2))
@@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
assert(plan sameResult plan)
val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan)
- val hashMode = HashedRelationBroadcastMode(output)
+ val hashMode = HashedRelationBroadcastMode(true, output, plan.output)
val exchange2 = BroadcastExchange(hashMode, plan)
val hashMode2 =
- HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
+ HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output)
val exchange3 = BroadcastExchange(hashMode2, plan)
val exchange4 = ReusedExchange(output, exchange3)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 371a9ed617..ed87a99439 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -30,23 +30,15 @@ import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
- val mm = new TaskMemoryManager(
- new StaticMemoryManager(
- new SparkConf().set("spark.memory.offHeap.enabled", "false"),
- Long.MaxValue,
- Long.MaxValue,
- 1),
- 0)
-
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy())
-
val buildKey = Seq(BoundReference(0, IntegerType, false))
- val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm)
+ val keyGenerator = UnsafeProjection.create(buildKey)
+ val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
@@ -108,45 +100,31 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
}
- test("LongToUnsafeRowMap") {
+ test("LongArrayRelation") {
val unsafeProj = UnsafeProjection.create(
Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
- val key = Seq(BoundReference(0, IntegerType, false))
- val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
- assert(longRelation.keyIsUnique)
+ val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
+ val longRelation = LongHashedRelation(rows.iterator, keyProj, 100)
+ assert(longRelation.isInstanceOf[LongArrayRelation])
+ val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
(0 until 100).foreach { i =>
- val row = longRelation.getValue(i)
+ val row = longArrayRelation.getValue(i)
assert(row.getInt(0) === i)
assert(row.getInt(1) === i + 1)
}
- val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
- assert(!longRelation2.keyIsUnique)
- (0 until 100).foreach { i =>
- val rows = longRelation2.get(i).toArray
- assert(rows.length === 2)
- assert(rows(0).getInt(0) === i)
- assert(rows(0).getInt(1) === i + 1)
- assert(rows(1).getInt(0) === i)
- assert(rows(1).getInt(1) === i + 1)
- }
-
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
- longRelation2.writeExternal(out)
+ longArrayRelation.writeExternal(out)
out.flush()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
- val relation = new LongHashedRelation()
+ val relation = new LongArrayRelation()
relation.readExternal(in)
- assert(!relation.keyIsUnique)
(0 until 100).foreach { i =>
- val rows = relation.get(i).toArray
- assert(rows.length === 2)
- assert(rows(0).getInt(0) === i)
- assert(rows(0).getInt(1) === i + 1)
- assert(rows(1).getInt(0) === i)
- assert(rows(1).getInt(1) === i + 1)
+ val row = longArrayRelation.getValue(i)
+ assert(row.getInt(0) === i)
+ assert(row.getInt(1) === i + 1)
}
}
}