diff options
author | Davies Liu <davies@databricks.com> | 2016-04-09 00:37:55 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-04-09 00:37:55 -0700 |
commit | 90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f (patch) | |
tree | bb0c38896d7b02e4dd612a68f22bcc982d383d08 | |
parent | 520dde48d0d52dbbbbe1710a3275fdd5355dd69d (diff) | |
download | spark-90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f.tar.gz spark-90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f.tar.bz2 spark-90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f.zip |
[SPARK-14419] [SQL] Improve HashedRelation for key fit within Long
## What changes were proposed in this pull request?
Currently, we use java HashMap for HashedRelation if the key could fit within a Long. The java HashMap and CompactBuffer are not memory efficient, the memory used by them is also accounted accurately.
This PR introduce a LongToUnsafeRowMap (similar to BytesToBytesMap) for better memory efficiency and performance.
## How was this patch tested?
Updated existing tests.
Author: Davies Liu <davies@databricks.com>
Closes #12190 from davies/long_map2.
8 files changed, 633 insertions, 346 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 0a5a72c52a..692fef703f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -454,7 +454,7 @@ case class TungstenAggregate( val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + ctx.addMutableState(hashMapClassName, hashMapTerm, s"") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") @@ -467,6 +467,7 @@ case class TungstenAggregate( s""" ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { + $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index e3d554c2de..a8f854136c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -50,10 +51,7 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode( - canJoinKeyFitWithinLong, - rewriteKeyExpr(buildKeys), - buildPlan.output) + val mode = HashedRelationBroadcastMode(buildKeys) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -68,7 +66,7 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) join(streamedIter, hashed, numOutputRows) } } @@ -105,7 +103,7 @@ case class BroadcastHashJoin( ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.getMemorySize()); + | incPeakExecutionMemory($relationTerm.estimatedSize()); """.stripMargin) (broadcastRelation, relationTerm) } @@ -118,15 +116,13 @@ case class BroadcastHashJoin( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - if (canJoinKeyFitWithinLong) { + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { // generate the join key as Long - val expr = rewriteKeyExpr(streamedKeys).head - val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) + val ev = streamedKeys.head.gen(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow - val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) - val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) (ev, s"${ev.value}.anyNull()") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8f45d57126..4c912d371e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -59,9 +59,13 @@ trait HashJoin { case BuildRight => (right, left) } - protected lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) + protected lazy val (buildKeys, streamedKeys) = { + val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + buildSide match { + case BuildLeft => (lkeys, rkeys) + case BuildRight => (rkeys, lkeys) + } } /** @@ -84,17 +88,8 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 - // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same - // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys - // with two same ints have hash code 0, we rotate the bits of second one. - val rotated = if (e.dataType == IntegerType) { - // (e >>> 15) | (e << 17) - BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) - } else { - e - } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType @@ -105,17 +100,11 @@ trait HashJoin { keyExpr :: Nil } - protected lazy val canJoinKeyFitWithinLong: Boolean = { - val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType) - val key = rewriteKeyExpr(buildKeys) - sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] - } - protected def buildSideKeyGenerator(): Projection = - UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) + UnsafeProjection.create(buildKeys) protected def streamSideKeyGenerator(): UnsafeProjection = - UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) + UnsafeProjection.create(streamedKeys) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 5ccb435686..4959f60dab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,24 +18,22 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.{HashMap => JavaHashMap} -import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} -import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} -import org.apache.spark.util.collection.CompactBuffer /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[execution] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation extends KnownSizeEstimation { /** * Returns matched rows. * @@ -75,50 +73,35 @@ private[execution] sealed trait HashedRelation { def asReadOnlyCopy(): HashedRelation /** - * Returns the size of used memory. - */ - def getMemorySize: Long = 1L // to make the test happy - - /** * Release any used resources. */ - def close(): Unit = {} - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { - out.writeInt(serialized.length) // Write the length of serialized bytes first - out.write(serialized) - } - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def readBytes(in: ObjectInput): Array[Byte] = { - val serializedSize = in.readInt() // Read the length of serialized bytes first - val bytes = new Array[Byte](serializedSize) - in.readFully(bytes) - bytes - } + def close(): Unit } private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. - * - * Note: The caller should make sure that these InternalRow are different objects. */ def apply( - canJoinKeyFitWithinLong: Boolean, input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int = 64): HashedRelation = { + key: Seq[Expression], + sizeEstimate: Int = 64, + taskMemoryManager: TaskMemoryManager = null): HashedRelation = { + val mm = Option(taskMemoryManager).getOrElse { + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + } - if (canJoinKeyFitWithinLong) { - LongHashedRelation(input, keyGenerator, sizeEstimate) + if (key.length == 1 && key.head.dataType == LongType) { + LongHashedRelation(input, key, sizeEstimate, mm) } else { - UnsafeHashedRelation( - input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + UnsafeHashedRelation(input, key, sizeEstimate, mm) } } } @@ -133,7 +116,7 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numFields: Int, private var binaryMap: BytesToBytesMap) - extends HashedRelation with KnownSizeEstimation with Externalizable { + extends HashedRelation with Externalizable { private[joins] def this() = this(0, null) // Needed for serialization @@ -142,10 +125,6 @@ private[joins] class UnsafeHashedRelation( override def asReadOnlyCopy(): UnsafeHashedRelation = new UnsafeHashedRelation(numFields, binaryMap) - override def getMemorySize: Long = { - binaryMap.getTotalMemoryConsumption - } - override def estimatedSize: Long = { binaryMap.getTotalMemoryConsumption } @@ -276,20 +255,10 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - keyGenerator: UnsafeProjection, - sizeEstimate: Int): HashedRelation = { + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): HashedRelation = { - val taskMemoryManager = if (TaskContext.get() != null) { - TaskContext.get().taskMemoryManager() - } else { - new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - } val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) @@ -300,6 +269,7 @@ private[joins] object UnsafeHashedRelation { pageSizeBytes) // Create a mapping of buildKeys -> rows + val keyGenerator = UnsafeProjection.create(key) var numFields = 0 while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] @@ -321,144 +291,471 @@ private[joins] object UnsafeHashedRelation { } } +private[joins] object LongToUnsafeRowMap { + // the largest prime that below 2^n + val LARGEST_PRIMES = { + // https://primes.utm.edu/lists/2small/0bit.html + val diffs = Seq( + 0, 1, 1, 3, 1, 3, 1, 5, + 3, 3, 9, 3, 1, 3, 19, 15, + 1, 5, 1, 3, 9, 3, 15, 3, + 39, 5, 39, 57, 3, 35, 1, 5 + ) + val primes = new Array[Int](32) + primes(0) = 1 + var power2 = 1 + (1 until 32).foreach { i => + power2 *= 2 + primes(i) = power2 - diffs(i) + } + primes + } +} + /** - * An interface for a hashed relation that the key is a Long. + * An append-only hash map mapping from key of Long to UnsafeRow. + * + * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array + * (`page`) in this format: + * + * [bytes of row1][address1][bytes of row2][address1] ... + * + * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key + * could have multiple values. the address at the end of last value for every key is 0. + * + * The keys and addresses of their values could be stored in two modes: + * + * 1) sparse mode: the keys and addresses are stored in `array` as: + * + * [key1][address1][key2][address2]...[] + * + * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1 + * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address + * hash collision. + * + * 2) dense mode: all the addresses are packed into a single array of long, as: + * + * [address1] [address2] ... + * + * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is + * determined by `key1 - minKey`. + * + * The map is created as sparse mode, then key-value could be appended into it. Once finish + * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * to probe. */ -private[joins] trait LongHashedRelation extends HashedRelation { - override def get(key: InternalRow): Iterator[InternalRow] = { - get(key.getLong(0)) +private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) + extends MemoryConsumer(mm) with Externalizable { + import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._ + + // Whether the keys are stored in dense mode or not. + private var isDense = false + + // The minimum value of keys. + private var minKey = Long.MaxValue + + // The Maxinum value of keys. + private var maxKey = Long.MinValue + + // Sparse mode: the actual capacity of map, is a prime number. + private var cap: Int = 0 + + // The array to store the key and offset of UnsafeRow in the page. + // + // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... + // Dense mode: [offset1 | size1] [offset2 | size2] + private var array: Array[Long] = null + + // The page to store all bytes of UnsafeRow and the pointer to next rows. + // [row1][pointer1] [row2][pointer2] + private var page: Array[Byte] = null + + // Current write cursor in the page. + private var cursor = Platform.BYTE_ARRAY_OFFSET + + // The total number of values of all keys. + private var numValues = 0 + + // The number of unique keys. + private var numKeys = 0 + + // needed by serializer + def this() = { + this( + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0), + 0) } - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) + + private def acquireMemory(size: Long): Unit = { + // do not support spilling + val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) + if (got < size) { + mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this) + throw new SparkException(s"Can't acquire $size bytes memory to build hash relation") + } } -} -private[joins] final class GeneralLongHashedRelation( - private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) - extends LongHashedRelation with Externalizable { + private def freeMemory(size: Long): Unit = { + mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) + } - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) + private def init(): Unit = { + if (mm != null) { + cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ + sys.error(s"Can't create map with capacity $capacity") + } + acquireMemory(cap * 2 * 8 + (1 << 20)) + array = new Array[Long](cap * 2) + page = new Array[Byte](1 << 20) // 1M bytes + } + } + + init() + + def spill(size: Long, trigger: MemoryConsumer): Long = { + 0L + } + + /** + * Returns whether all the keys are unique. + */ + def keyIsUnique: Boolean = numKeys == numValues + + /** + * Returns total memory consumption. + */ + def getTotalMemoryConsumption: Long = { + array.length * 8 + page.length + } - override def keyIsUnique: Boolean = false + /** + * Returns the slot of array that store the keys (sparse mode). + */ + private def getSlot(key: Long): Int = { + var s = (key % cap).toInt + if (s < 0) { + s += cap + } + s * 2 + } - override def asReadOnlyCopy(): GeneralLongHashedRelation = - new GeneralLongHashedRelation(hashTable) + private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { + val offset = address >>> 32 + val size = address & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + resultRow + } - override def get(key: Long): Iterator[InternalRow] = { - val rows = hashTable.get(key) - if (rows != null) { - rows.toIterator + /** + * Returns the single UnsafeRow for given key, or null if not found. + */ + def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >= 0 && key <= maxKey && array(idx) > 0) { + return getRow(array(idx), resultRow) + } } else { - null + var pos = getSlot(key) + var step = 1 + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return getRow(array(pos + 1), resultRow) + } + pos += 2 * step + step += 1 + if (pos >= array.length) { + pos -= array.length + } + } + } + null + } + + /** + * Returns an interator of UnsafeRow for multiple linked values. + */ + private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + new Iterator[UnsafeRow] { + var addr = address + override def hasNext: Boolean = addr != 0 + override def next(): UnsafeRow = { + val offset = addr >>> 32 + val size = addr & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + addr = Platform.getLong(page, offset + size) + resultRow + } + } + } + + /** + * Returns an iterator for all the values for the given key, or null if no value found. + */ + def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >=0 && key <= maxKey && array(idx) > 0) { + return valueIter(array(idx), resultRow) + } + } else { + var pos = getSlot(key) + var step = 1 + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return valueIter(array(pos + 1), resultRow) + } + pos += 2 * step + step += 1 + if (pos >= array.length) { + pos -= array.length + } + } + } + null + } + + /** + * Appends the key and row into this map. + */ + def append(key: Long, row: UnsafeRow): Unit = { + if (key < minKey) { + minKey = key + } + if (key > maxKey) { + maxKey = key + } + + // There is 8 bytes for the pointer to next value + if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { + val used = page.length + if (used * 2L > (1L << 31)) { + sys.error("Can't allocate a page that is larger than 2G") + } + acquireMemory(used * 2) + val newPage = new Array[Byte](used * 2) + System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) + page = newPage + freeMemory(used) + } + + // copy the bytes of UnsafeRow + val offset = cursor + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) + cursor += row.getSizeInBytes + Platform.putLong(page, cursor, 0) + cursor += 8 + numValues += 1 + updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) + } + + /** + * Update the address in array for given key. + */ + private def updateIndex(key: Long, address: Long): Unit = { + var pos = getSlot(key) + var step = 1 + while (array(pos + 1) != 0 && array(pos) != key) { + pos += 2 * step + step += 1 + if (pos >= array.length) { + pos -= array.length + } + } + if (array(pos + 1) == 0) { + // this is the first value for this key, put the address in array. + array(pos) = key + array(pos + 1) = address + numKeys += 1 + if (numKeys * 2 > cap) { + // reach half of the capacity + growArray() + } + } else { + // there is another value for this key, put the address at the end of final value. + var addr = array(pos + 1) + var pointer = (addr >>> 32) + (addr & 0xffffffffL) + while (Platform.getLong(page, pointer) != 0) { + addr = Platform.getLong(page, pointer) + pointer = (addr >>> 32) + (addr & 0xffffffffL) + } + Platform.putLong(page, pointer, address) + } + } + + private def growArray(): Unit = { + val old_cap = cap + var old_array = array + cap = LARGEST_PRIMES.find(_ > cap).getOrElse{ + sys.error(s"Can't grow map any more than $cap") + } + numKeys = 0 + acquireMemory(cap * 2 * 8) + array = new Array[Long](cap * 2) + var i = 0 + while (i < old_array.length) { + if (old_array(i + 1) > 0) { + updateIndex(old_array(i), old_array(i + 1)) + } + i += 2 + } + old_array = null // release the reference to old array + freeMemory(old_cap * 2 * 8) + } + + /** + * Try to turn the map into dense mode, which is faster to probe. + */ + def optimize(): Unit = { + val range = maxKey - minKey + // Convert to dense mode if it does not require more memory or could fit within L1 cache + if (range < array.length || range < 1024) { + try { + acquireMemory((range + 1) * 8) + } catch { + case e: SparkException => + // there is no enough memory to convert + return + } + val denseArray = new Array[Long]((range + 1).toInt) + var i = 0 + while (i < array.length) { + if (array(i + 1) > 0) { + val idx = (array(i) - minKey).toInt + denseArray(idx) = array(i + 1) + } + i += 2 + } + val old_length = array.length + array = denseArray + isDense = true + freeMemory(old_length * 8) + } + } + + /** + * Free all the memory acquired by this map. + */ + def free(): Unit = { + if (page != null) { + freeMemory(page.length) + page = null + } + if (array != null) { + freeMemory(array.length * 8) + array = null } } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + out.writeBoolean(isDense) + out.writeLong(minKey) + out.writeLong(maxKey) + out.writeInt(numKeys) + out.writeInt(numValues) + out.writeInt(cap) + + out.writeInt(array.length) + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) + out.write(buffer, 0, size) + offset += size + } + + val used = cursor - Platform.BYTE_ARRAY_OFFSET + out.writeInt(used) + out.write(page, 0, used) } override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + isDense = in.readBoolean() + minKey = in.readLong() + maxKey = in.readLong() + numKeys = in.readInt() + numValues = in.readInt() + cap = in.readInt() + + val length = in.readInt() + array = new Array[Long](length) + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + in.readFully(buffer, 0, size) + Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) + offset += size + } + + val numBytes = in.readInt() + page = new Array[Byte](numBytes) + in.readFully(page) } } -/** - * A relation that pack all the rows into a byte array, together with offsets and sizes. - * - * All the bytes of UnsafeRow are packed together as `bytes`: - * - * [ Row0 ][ Row1 ][] ... [ RowN ] - * - * With keys: - * - * start start+1 ... start+N - * - * `offsets` are offsets of UnsafeRows in the `bytes` - * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. - * - * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: - * - * start = 3 - * offsets = [0, 0, 24] - * sizes = [24, 0, 32] - * bytes = [0 - 24][][24 - 56] - */ -private[joins] final class LongArrayRelation( - private var numFields: Int, - private var start: Long, - private var offsets: Array[Int], - private var sizes: Array[Int], - private var bytes: Array[Byte] - ) extends LongHashedRelation with Externalizable { +private[joins] class LongHashedRelation( + private var nFields: Int, + private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { + + private var resultRow: UnsafeRow = new UnsafeRow(nFields) // Needed for serialization (it is public to make Java serialization work) - def this() = this(0, 0L, null, null, null) + def this() = this(0, null) - override def keyIsUnique: Boolean = true + override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) - override def asReadOnlyCopy(): LongArrayRelation = { - new LongArrayRelation(numFields, start, offsets, sizes, bytes) + override def estimatedSize: Long = { + map.getTotalMemoryConsumption } - override def getMemorySize: Long = { - offsets.length * 4 + sizes.length * 4 + bytes.length + override def get(key: InternalRow): Iterator[InternalRow] = { + if (key.isNullAt(0)) { + null + } else { + get(key.getLong(0)) + } } - override def get(key: Long): Iterator[InternalRow] = { - val row = getValue(key) - if (row != null) { - Seq(row).toIterator - } else { + override def getValue(key: InternalRow): InternalRow = { + if (key.isNullAt(0)) { null + } else { + getValue(key.getLong(0)) } } - var resultRow = new UnsafeRow(numFields) + override def get(key: Long): Iterator[InternalRow] = + map.get(key, resultRow) + override def getValue(key: Long): InternalRow = { - val idx = (key - start).toInt - if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { - resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) - resultRow - } else { - null - } + map.getValue(key, resultRow) + } + + override def keyIsUnique: Boolean = map.keyIsUnique + + override def close(): Unit = { + map.free() } override def writeExternal(out: ObjectOutput): Unit = { - out.writeInt(numFields) - out.writeLong(start) - out.writeInt(sizes.length) - var i = 0 - while (i < sizes.length) { - out.writeInt(sizes(i)) - i += 1 - } - out.writeInt(bytes.length) - out.write(bytes) + out.writeInt(nFields) + out.writeObject(map) } override def readExternal(in: ObjectInput): Unit = { - numFields = in.readInt() - resultRow = new UnsafeRow(numFields) - start = in.readLong() - val length = in.readInt() - // read sizes of rows - sizes = new Array[Int](length) - offsets = new Array[Int](length) - var i = 0 - var offset = 0 - while (i < length) { - offsets(i) = offset - sizes(i) = in.readInt() - offset += sizes(i) - i += 1 - } - // read all the bytes - val total = in.readInt() - assert(total == offset) - bytes = new Array[Byte](total) - in.readFully(bytes) + nFields = in.readInt() + resultRow = new UnsafeRow(nFields) + map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } } @@ -466,96 +763,45 @@ private[joins] final class LongArrayRelation( * Create hashed relation with key that is long. */ private[joins] object LongHashedRelation { - - val DENSE_FACTOR = 0.2 - def apply( - input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int): HashedRelation = { + input: Iterator[InternalRow], + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): LongHashedRelation = { - // TODO: use LongToBytesMap for better memory efficiency - val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) + val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) + val keyGenerator = UnsafeProjection.create(key) // Create a mapping of key -> rows var numFields = 0 - var keyIsUnique = true - var minKey = Long.MaxValue - var maxKey = Long.MinValue while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) - if (!rowKey.anyNull) { + if (!rowKey.isNullAt(0)) { val key = rowKey.getLong(0) - minKey = math.min(minKey, key) - maxKey = math.max(maxKey, key) - val existingMatchList = hashTable.get(key) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[UnsafeRow]() - hashTable.put(key, newMatchList) - newMatchList - } else { - keyIsUnique = false - existingMatchList - } - matchList += unsafeRow - } - } - - if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { - // The keys are dense enough, so use LongArrayRelation - val length = (maxKey - minKey).toInt + 1 - val sizes = new Array[Int](length) - val offsets = new Array[Int](length) - var offset = 0 - var i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - offsets(i) = offset - sizes(i) = rows(0).getSizeInBytes - offset += sizes(i) - } - i += 1 - } - val bytes = new Array[Byte](offset) - i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) - } - i += 1 + map.append(key, unsafeRow) } - new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) - } else { - new GeneralLongHashedRelation(hashTable) } + map.optimize() + new LongHashedRelation(numFields, map) } } /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */ -private[execution] case class HashedRelationBroadcastMode( - canJoinKeyFitWithinLong: Boolean, - keys: Seq[Expression], - attributes: Seq[Attribute]) extends BroadcastMode { +private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) + extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - val generator = UnsafeProjection.create(keys, attributes) - HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length) + HashedRelation(rows.iterator, canonicalizedKey, rows.length) } - private lazy val canonicalizedKeys: Seq[Expression] = { - keys.map { e => - BindReferences.bindReference(e.canonicalized, attributes) - } + private lazy val canonicalizedKey: Seq[Expression] = { + key.map { e => e.canonicalized } } override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => - canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && - canonicalizedKeys == m.canonicalizedKeys + case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index bf86096379..0c3e3c3fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.memory.MemoryMode +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -57,54 +56,20 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { + private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val context = TaskContext.get() - if (!canJoinKeyFitWithinLong) { - // build BytesToBytesMap - val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator) - // This relation is usually used until the end of task. - context.addTaskCompletionListener((t: TaskContext) => - relation.close() - ) - return relation - } - - // try to acquire some memory for the hash table, it could trigger other operator to free some - // memory. The memory acquired here will mostly be used until the end of task. - val memoryManager = context.taskMemoryManager() - var acquired = 0L - var used = 0L + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + // This relation is usually used until the end of task. context.addTaskCompletionListener((t: TaskContext) => - memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null) + relation.close() ) - - val copiedIter = iter.map { row => - // It's hard to guess what's exactly memory will be used, we have a rough guess here. - // TODO: use LongToBytesMap instead of HashMap for memory efficiency - // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers - val needed = 150 + row.getSizeInBytes - if (needed > acquired - used) { - val got = memoryManager.acquireExecutionMemory( - Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null) - acquired += got - if (got < needed) { - throw new SparkException("Can't acquire enough memory to build hash map in shuffled" + - "hash join, please use sort merge join by setting " + - "spark.sql.join.preferSortMergeJoin=true") - } - } - used += needed - // HashedRelation requires that the UnsafeRow should be separate objects. - row.copy() - } - - HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator) + relation } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => - val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) + val hashed = buildHashedRelation(buildIter) join(streamIter, hashed, numOutputRows) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 5dbf619876..352fd07d0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,6 +21,7 @@ import java.util.HashMap import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.AggregateHashMap @@ -179,8 +180,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X - Join w long codegen=true 275 / 352 76.2 13.1 19.4X + Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X + Join w long codegen=true 321 / 371 65.3 15.3 9.3X */ runBenchmark("Join w long duplicated", N) { @@ -193,8 +194,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X - Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X + Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X + Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X */ val dim2 = broadcast(sqlContext.range(M) @@ -211,8 +212,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X - Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X + Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X + Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X */ val dim3 = broadcast(sqlContext.range(M) @@ -259,8 +260,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X - outer join w long codegen=true 216 / 226 97.2 10.3 26.3X + outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X + outer join w long codegen=true 261 / 276 80.5 12.4 11.7X */ runBenchmark("semi join w long", N) { @@ -272,8 +273,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X - semi join w long codegen=true 211 / 229 99.2 10.1 22.2X + semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X + semi join w long codegen=true 237 / 244 88.3 11.3 8.1X */ } @@ -326,8 +327,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X - shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X + shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X + shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X */ } @@ -349,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 10 << 20 + val N = 20 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) - benchmark.addCase("hash") { iter => + benchmark.addCase("UnsafeRowhash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) @@ -368,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("murmur3 hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = Murmur3_x86_32.hashLong(i, 42) + key.setInt(0, h) + s += h + i += 1 + } + } + benchmark.addCase("fast hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 var s = 0 while (i < N) { - key.setInt(0, i % 1000) - val h = Murmur3_x86_32.hashLong(i % 1000, 42) + var h = i % p + if (h < 0) { + h += p + } + key.setInt(0, h) s += h i += 1 } @@ -475,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + Seq(false, true).foreach { optimized => + benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + while (i < 65536) { + value.setInt(0, i) + val key = i % 100000 + map.append(key, value) + i += 1 + } + if (optimized) { + map.optimize() + } + var s = 0 + i = 0 + while (i < N) { + val key = i % 100000 + if (map.getValue(key, value) != null) { + s += 1 + } + i += 1 + } + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -493,18 +549,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - while (i < N) { + val numKeys = 65536 + while (i < numKeys) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, Murmur3_x86_32.hashLong(i % 65536, 42)) - if (loc.isDefined) { - value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - value.setInt(0, value.getInt(0) + 1) - i += 1 - } else { + if (!loc.isDefined) { loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) } + i += 1 + } + i = 0 + var s = 0 + while (i < N) { + key.setInt(0, i % 100000) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 100000, 42)) + if (loc.isDefined) { + s += 1 + } + i += 1 } } } @@ -535,16 +600,19 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - hash 112 / 116 93.2 10.7 1.0X - fast hash 65 / 69 160.9 6.2 1.7X - arrayEqual 66 / 69 159.1 6.3 1.7X - Java HashMap (Long) 137 / 182 76.3 13.1 0.8X - Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X - Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X - BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X - BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X - Aggregate HashMap 56 / 62 187.9 5.3 2.0X - */ + UnsafeRow hash 267 / 284 78.4 12.8 1.0X + murmur3 hash 102 / 129 205.5 4.9 2.6X + fast hash 79 / 96 263.8 3.8 3.4X + arrayEqual 164 / 172 128.2 7.8 1.6X + Java HashMap (Long) 321 / 399 65.4 15.3 0.8X + Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X + Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X + LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X + LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X + BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X + BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X + Aggregate HashMap 121 / 131 173.3 5.8 2.2X + */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 9680f3a008..17f2343cf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode - val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) - val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) + val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) + val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) @@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan) - val hashMode = HashedRelationBroadcastMode(true, output, plan.output) + val hashMode = HashedRelationBroadcastMode(output) val exchange2 = BroadcastExchange(hashMode, plan) val hashMode2 = - HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) + HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) val exchange3 = BroadcastExchange(hashMode2, plan) val exchange4 = ReusedExchange(output, exchange3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ed87a99439..371a9ed617 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,15 +30,23 @@ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { + val mm = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()) + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) @@ -100,31 +108,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } - test("LongArrayRelation") { + test("LongToUnsafeRowMap") { val unsafeProj = UnsafeProjection.create( Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) - val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) - val longRelation = LongHashedRelation(rows.iterator, keyProj, 100) - assert(longRelation.isInstanceOf[LongArrayRelation]) - val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] + val key = Seq(BoundReference(0, IntegerType, false)) + val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) + assert(longRelation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) + val row = longRelation.getValue(i) assert(row.getInt(0) === i) assert(row.getInt(1) === i + 1) } + val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) + assert(!longRelation2.keyIsUnique) + (0 until 100).foreach { i => + val rows = longRelation2.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) + } + val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) - longArrayRelation.writeExternal(out) + longRelation2.writeExternal(out) out.flush() val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) - val relation = new LongArrayRelation() + val relation = new LongHashedRelation() relation.readExternal(in) + assert(!relation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) - assert(row.getInt(0) === i) - assert(row.getInt(1) === i + 1) + val rows = relation.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) } } } |