-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala              |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala        |  96
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala   |   7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala                 |  44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala           | 297
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala     |  27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala      |  29
8 files changed, 438 insertions(+), 69 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 131efea20f..4ca2d85406 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -38,6 +38,7 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
+ case _: BroadcastHashJoin => "bhj"
case _ => nodeName.toLowerCase
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 943ad31c0c..cbd549763a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -90,8 +90,14 @@ case class BroadcastHashJoin(
// The following line doesn't run in a job so we cannot track the metric value. However, we
// have already tracked it in the above lines. So here we can use
// `SQLMetrics.nullLongMetric` to ignore it.
- val hashed = HashedRelation(
- input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+ // TODO: move this check into HashedRelation
+ val hashed = if (canJoinKeyFitWithinLong) {
+ LongHashedRelation(
+ input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+ } else {
+ HashedRelation(
+ input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+ }
sparkContext.broadcast(hashed)
}
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
@@ -112,15 +118,12 @@ case class BroadcastHashJoin(
streamedPlan.execute().mapPartitions { streamedIter =>
val hashedRelation = broadcastRelation.value
- hashedRelation match {
- case unsafe: UnsafeHashedRelation =>
- TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
- case _ =>
- }
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}
+ private var broadcastRelation: Broadcast[HashedRelation] = _
// the term for hash relation
private var relationTerm: String = _
@@ -129,16 +132,15 @@ case class BroadcastHashJoin(
}
override def doProduce(ctx: CodegenContext): String = {
- // create a name for HashRelation
- val broadcastRelation = Await.result(broadcastFuture, timeout)
+ // create a name for HashedRelation
+ broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
relationTerm = ctx.freshName("relation")
- // TODO: create specialized HashRelation for single join key
- val clsName = classOf[UnsafeHashedRelation].getName
+ val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = ($clsName) $broadcast.value();
- | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+ | incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)
s"""
@@ -147,23 +149,24 @@ case class BroadcastHashJoin(
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- // generate the key as UnsafeRow
+ // generate the key as UnsafeRow or Long
ctx.currentVars = input
- val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
- val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
- val keyTerm = keyVal.value
- val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+ val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+ val expr = rewriteKeyExpr(streamedKeys).head
+ val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
+ (ev, ev.isNull)
+ } else {
+ val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+ val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+ (ev, s"${ev.value}.anyNull()")
+ }
// find the matches from HashedRelation
- val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
- val row = ctx.freshName("row")
+ val matched = ctx.freshName("matched")
// create variables for output
ctx.currentVars = null
- ctx.INPUT_ROW = row
+ ctx.INPUT_ROW = matched
val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).gen(ctx)
}
@@ -172,7 +175,7 @@ case class BroadcastHashJoin(
case BuildRight => input ++ buildColumns
}
- val ouputCode = if (condition.isDefined) {
+ val outputCode = if (condition.isDefined) {
// filter the output via condition
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
@@ -186,20 +189,39 @@ case class BroadcastHashJoin(
consume(ctx, resultVars)
}
- s"""
- | // generate join key
- | ${keyVal.code}
- | // find matches from HashRelation
- | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
- | if ($matches != null) {
- | int $size = $matches.size();
- | for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $row = (UnsafeRow) $matches.apply($i);
- | ${buildColumns.map(_.code).mkString("\n")}
- | $ouputCode
- | }
- | }
+ if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ s"""
+ | // generate join key
+ | ${keyVal.code}
+ | // find matches from HashedRelation
+ | UnsafeRow $matched = $anyNull ? null : (UnsafeRow) $relationTerm.getValue(${keyVal.value});
+ | if ($matched != null) {
+ | ${buildColumns.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
""".stripMargin
+
+ } else {
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ s"""
+ | // generate join key
+ | ${keyVal.code}
+ | // find matches from HashedRelation
+ | $bufferType $matches = ${anyNull} ? null :
+ | ($bufferType) $relationTerm.get(${keyVal.value});
+ | if ($matches != null) {
+ | int $size = $matches.size();
+ | for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ | ${buildColumns.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
+ | }
+ """.stripMargin
+ }
}
}
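The two templates above are chosen from the runtime class of the broadcast relation: a straight-line lookup when keys are unique, a buffer loop otherwise. A minimal, self-contained Scala sketch of those two shapes, with plain Maps standing in for HashedRelation and println standing in for the consume step (all names here are illustrative):

    object GeneratedLookupShapes {
      // Unique keys: getValue() yields at most one row, so no buffer or loop.
      def uniqueShape(relation: Map[Long, String], key: Long): Unit = {
        val matched = relation.getOrElse(key, null)
        if (matched != null) println(matched)
      }

      // General case: get() yields every matching row, which is iterated.
      def generalShape(relation: Map[Long, Seq[String]], key: Long): Unit = {
        val matches = relation.getOrElse(key, null)
        if (matches != null) matches.foreach(println)
      }

      def main(args: Array[String]): Unit = {
        uniqueShape(Map(1L -> "row1"), 1L)
        generalShape(Map(1L -> Seq("row1", "row2")), 1L)
      }
    }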
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index f48fc3b848..ad3275696e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -116,12 +116,7 @@ case class BroadcastHashOuterJoin(
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
val keyGenerator = streamedKeyGenerator
-
- hashTable match {
- case unsafe: UnsafeHashedRelation =>
- TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
- case _ =>
- }
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
val resultProj = resultProjection
joinType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 8929dc3af1..d0e18dfcf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -64,11 +64,7 @@ case class BroadcastLeftSemiJoinHash(
left.execute().mapPartitionsInternal { streamIter =>
val hashedRelation = broadcastedRelation.value
- hashedRelation match {
- case unsafe: UnsafeHashedRelation =>
- TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
- case _ =>
- }
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
}
}
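All three operators now report peak memory through a single HashedRelation method instead of pattern-matching on UnsafeHashedRelation. A stripped-down sketch of that polymorphic accounting (types and numbers here are illustrative, not Spark's):

    object MemoryAccounting {
      trait Relation { def getMemorySize: Long = 1L }    // trait default, as in this patch
      class UnsafeRelation(bytes: Long) extends Relation {
        override def getMemorySize: Long = bytes         // real consumption when known
      }

      def main(args: Array[String]): Unit = {
        val rels = Seq[Relation](new Relation {}, new UnsafeRelation(1L << 20))
        rels.foreach(r => println(r.getMemorySize))      // prints 1, then 1048576
      }
    }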
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8ef854001f..ecbb1ac64b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.LongSQLMetric
-
+import org.apache.spark.sql.types.{IntegralType, LongType}
trait HashJoin {
self: SparkPlan =>
@@ -47,11 +47,49 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
+ /**
+ * Try to rewrite the key as LongType so we can use getLong(), if the key can fit within a long.
+ *
+ * If not, returns the original expressions.
+ */
+ def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+ var keyExpr: Expression = null
+ var width = 0
+ keys.foreach { e =>
+ e.dataType match {
+ case dt: IntegralType if dt.defaultSize <= 8 - width =>
+ if (width == 0) {
+ if (e.dataType != LongType) {
+ keyExpr = Cast(e, LongType)
+ } else {
+ keyExpr = e
+ }
+ width = dt.defaultSize
+ } else {
+ val bits = dt.defaultSize * 8
+ keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+ BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+ width += dt.defaultSize // accumulate the packed width in bytes
+ }
+ // TODO: support BooleanType, DateType and TimestampType
+ case other =>
+ return keys
+ }
+ }
+ keyExpr :: Nil
+ }
+
+ protected val canJoinKeyFitWithinLong: Boolean = {
+ val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
+ val key = rewriteKeyExpr(buildKeys)
+ sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
+ }
+
protected def buildSideKeyGenerator: Projection =
- UnsafeProjection.create(buildKeys, buildPlan.output)
+ UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
protected def streamSideKeyGenerator: Projection =
- UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
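rewriteKeyExpr packs several narrow integral keys into one long: each additional key is shifted in with ShiftLeft and masked with BitwiseAnd so only its own bits survive. A self-contained sketch of the same packing for two int keys (illustrative only; the real method builds Catalyst expressions rather than evaluating directly):

    object PackKeys {
      // Two 4-byte ints occupy the high and low halves of one 8-byte long.
      def pack(k1: Int, k2: Int): Long =
        (k1.toLong << 32) | (k2.toLong & 0xFFFFFFFFL)

      def main(args: Array[String]): Unit = {
        assert(pack(1, 2) == 0x0000000100000002L)
        // Masking k2 keeps its sign extension from clobbering k1's half.
        assert(pack(1, -2) == 0x00000001FFFFFFFEL)
        println("ok")
      }
    }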
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ee7a1bdc34..c94d6c195b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -39,8 +39,23 @@ import org.apache.spark.util.collection.CompactBuffer
* object.
*/
private[execution] sealed trait HashedRelation {
+ /**
+ * Returns matched rows.
+ */
def get(key: InternalRow): Seq[InternalRow]
+ /**
+ * Returns matched rows for a key that has only one column with LongType.
+ */
+ def get(key: Long): Seq[InternalRow] = {
+ throw new UnsupportedOperationException
+ }
+
+ /**
+ * Returns the size of used memory.
+ */
+ def getMemorySize: Long = 1L // to make the test happy
+
// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -58,11 +73,48 @@ private[execution] sealed trait HashedRelation {
}
}
+/**
+ * Interface for a hashed relation that has only one row per key.
+ *
+ * Callers should use getValue() for better performance.
+ */
+private[execution] trait UniqueHashedRelation extends HashedRelation {
+
+ /**
+ * Returns the matched single row.
+ */
+ def getValue(key: InternalRow): InternalRow
+
+ /**
+ * Returns the matched single row for a key that has only one column of LongType.
+ */
+ def getValue(key: Long): InternalRow = {
+ throw new UnsupportedOperationException
+ }
+
+ override def get(key: InternalRow): Seq[InternalRow] = {
+ val row = getValue(key)
+ if (row != null) {
+ CompactBuffer[InternalRow](row)
+ } else {
+ null
+ }
+ }
+
+ override def get(key: Long): Seq[InternalRow] = {
+ val row = getValue(key)
+ if (row != null) {
+ CompactBuffer[InternalRow](row)
+ } else {
+ null
+ }
+ }
+}
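UniqueHashedRelation adapts the single-value getValue() lookup to the multi-value get() interface by wrapping the hit in a one-element buffer; returning null rather than an empty Seq matches the null checks in the join paths. A plain-Scala analogue of the adapter (standard collections stand in for the relation):

    object UniqueAsMulti {
      def getValue(table: Map[Long, String], key: Long): String =
        table.getOrElse(key, null)

      // Wrap the single hit so callers written against get() keep working.
      def get(table: Map[Long, String], key: Long): Seq[String] = {
        val row = getValue(table, key)
        if (row != null) Seq(row) else null
      }

      def main(args: Array[String]): Unit = {
        assert(get(Map(1L -> "row1"), 1L) == Seq("row1"))
        assert(get(Map(1L -> "row1"), 2L) == null)
        println("ok")
      }
    }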
/**
* A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
*/
-private[joins] final class GeneralHashedRelation(
+private[joins] class GeneralHashedRelation(
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {
@@ -85,19 +137,14 @@ private[joins] final class GeneralHashedRelation(
* A specialized [[HashedRelation]] that maps key into a single value. This implementation
* assumes the key is unique.
*/
-private[joins]
-final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
- extends HashedRelation with Externalizable {
+private[joins] class UniqueKeyHashedRelation(
+ private var hashTable: JavaHashMap[InternalRow, InternalRow])
+ extends UniqueHashedRelation with Externalizable {
// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)
- override def get(key: InternalRow): Seq[InternalRow] = {
- val v = hashTable.get(key)
- if (v eq null) null else CompactBuffer(v)
- }
-
- def getValue(key: InternalRow): InternalRow = hashTable.get(key)
+ override def getValue(key: InternalRow): InternalRow = hashTable.get(key)
override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -108,8 +155,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
}
}
-// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
-
private[execution] object HashedRelation {
@@ -208,7 +253,7 @@ private[joins] final class UnsafeHashedRelation(
*
* For non-broadcast joins or in local mode, return 0.
*/
- def getUnsafeSize: Long = {
+ override def getMemorySize: Long = {
if (binaryMap != null) {
binaryMap.getTotalMemoryConsumption
} else {
@@ -408,6 +453,232 @@ private[joins] object UnsafeHashedRelation {
}
}
+ // TODO: create UniqueUnsafeRelation
new UnsafeHashedRelation(hashTable)
}
}
+
+/**
+ * An interface for a hashed relation whose key is a Long.
+ */
+private[joins] trait LongHashedRelation extends HashedRelation {
+ override def get(key: InternalRow): Seq[InternalRow] = {
+ get(key.getLong(0))
+ }
+}
+
+private[joins] final class GeneralLongHashedRelation(
+ private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
+ extends LongHashedRelation with Externalizable {
+
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(null)
+
+ override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ }
+}
+
+private[joins] final class UniqueLongHashedRelation(
+ private var hashTable: JavaHashMap[Long, UnsafeRow])
+ extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(null)
+
+ override def getValue(key: InternalRow): InternalRow = {
+ getValue(key.getLong(0))
+ }
+
+ override def getValue(key: Long): InternalRow = {
+ hashTable.get(key)
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ }
+}
+
+/**
+ * A relation that packs all the rows into a byte array, together with offsets and sizes.
+ *
+ * All the bytes of UnsafeRow are packed together as `bytes`:
+ *
+ * [ Row0 ][ Row1 ][] ... [ RowN ]
+ *
+ * With keys:
+ *
+ * start start+1 ... start+N
+ *
+ * `offsets` are the offsets of the UnsafeRows within `bytes`.
+ * `sizes` are the sizes in bytes of the UnsafeRows; 0 means there is no row for that key.
+ *
+ * For example, two UnsafeRows (24 bytes and 32 bytes) with keys 3 and 5 will be stored as:
+ *
+ * start = 3
+ * offsets = [0, 0, 24]
+ * sizes = [24, 0, 32]
+ * bytes = [0 - 24][][24 - 56]
+ */
+private[joins] final class LongArrayRelation(
+ private var numFields: Int,
+ private var start: Long,
+ private var offsets: Array[Int],
+ private var sizes: Array[Int],
+ private var bytes: Array[Byte]
+ ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(0, 0L, null, null, null)
+
+ override def getValue(key: InternalRow): InternalRow = {
+ getValue(key.getLong(0))
+ }
+
+ override def getMemorySize: Long = {
+ offsets.length * 4 + sizes.length * 4 + bytes.length
+ }
+
+ override def getValue(key: Long): InternalRow = {
+ val idx = (key - start).toInt
+ if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
+ val result = new UnsafeRow(numFields)
+ result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+ result
+ } else {
+ null
+ }
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ out.writeInt(numFields)
+ out.writeLong(start)
+ out.writeInt(sizes.length)
+ var i = 0
+ while (i < sizes.length) {
+ out.writeInt(sizes(i))
+ i += 1
+ }
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ numFields = in.readInt()
+ start = in.readLong()
+ val length = in.readInt()
+ // read sizes of rows
+ sizes = new Array[Int](length)
+ offsets = new Array[Int](length)
+ var i = 0
+ var offset = 0
+ while (i < length) {
+ offsets(i) = offset
+ sizes(i) = in.readInt()
+ offset += sizes(i)
+ i += 1
+ }
+ // read all the bytes
+ val total = in.readInt()
+ assert(total == offset)
+ bytes = new Array[Byte](total)
+ in.readFully(bytes)
+ }
+}
+
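The layout described in the class comment can be reproduced with plain arrays. A self-contained walk-through matching the 24-byte/32-byte example above (keys 3 and 5; offsets and sizes only, no real UnsafeRows):

    object DenseLayoutDemo {
      def main(args: Array[String]): Unit = {
        val rowSizeByKey = Map(3L -> 24, 5L -> 32)   // key -> row size in bytes
        val (minKey, maxKey) = (3L, 5L)
        val length = (maxKey - minKey).toInt + 1     // slots for keys 3, 4, 5
        val sizes = new Array[Int](length)
        val offsets = new Array[Int](length)
        var offset = 0
        for (i <- 0 until length) {
          rowSizeByKey.get(i + minKey).foreach { size =>
            offsets(i) = offset
            sizes(i) = size
            offset += size
          }
        }
        assert(offsets.toSeq == Seq(0, 0, 24))
        assert(sizes.toSeq == Seq(24, 0, 32))
        assert(offset == 56)                         // total length of `bytes`
        println("ok")
      }
    }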
+/**
+ * Creates a hashed relation whose key is a Long.
+ */
+private[joins] object LongHashedRelation {
+
+ val DENSE_FACTOR = 0.2
+
+ def apply(
+ input: Iterator[InternalRow],
+ numInputRows: LongSQLMetric,
+ keyGenerator: Projection,
+ sizeEstimate: Int): HashedRelation = {
+
+ // Use a Java hash table here because unsafe maps expect fixed size records
+ val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+
+ // Create a mapping of key -> rows
+ var numFields = 0
+ var keyIsUnique = true
+ var minKey = Long.MaxValue
+ var maxKey = Long.MinValue
+ while (input.hasNext) {
+ val unsafeRow = input.next().asInstanceOf[UnsafeRow]
+ numFields = unsafeRow.numFields()
+ numInputRows += 1
+ val rowKey = keyGenerator(unsafeRow)
+ if (!rowKey.anyNull) {
+ val key = rowKey.getLong(0)
+ minKey = math.min(minKey, key)
+ maxKey = math.max(maxKey, key)
+ val existingMatchList = hashTable.get(key)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[UnsafeRow]()
+ hashTable.put(key, newMatchList)
+ newMatchList
+ } else {
+ keyIsUnique = false
+ existingMatchList
+ }
+ matchList += unsafeRow.copy()
+ }
+ }
+
+ if (keyIsUnique) {
+ if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+ // The keys are dense enough, so use LongArrayRelation
+ val length = (maxKey - minKey).toInt + 1
+ val sizes = new Array[Int](length)
+ val offsets = new Array[Int](length)
+ var offset = 0
+ var i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ offsets(i) = offset
+ sizes(i) = rows(0).getSizeInBytes
+ offset += sizes(i)
+ }
+ i += 1
+ }
+ val bytes = new Array[Byte](offset)
+ i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
+ }
+ i += 1
+ }
+ new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
+
+ } else {
+ // The keys are unique but sparse, so use a hash map with one row per key.
+ val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
+ val iter = hashTable.entrySet().iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ uniqHashTable.put(entry.getKey, entry.getValue()(0))
+ }
+ new UniqueLongHashedRelation(uniqHashTable)
+ }
+ } else {
+ new GeneralLongHashedRelation(hashTable)
+ }
+ }
+}
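The DENSE_FACTOR check above picks between the two unique-key layouts: the array layout only pays off when the distinct keys cover a reasonable fraction of their range, since the arrays are sized by (maxKey - minKey + 1). A small sketch of that decision (illustrative; the real code tracks min/max while building the table):

    object DensityCheck {
      val DenseFactor = 0.2

      def useDenseLayout(keys: Seq[Long]): Boolean =
        keys.distinct.size > (keys.max - keys.min) * DenseFactor

      def main(args: Array[String]): Unit = {
        assert(useDenseLayout(Seq(3L, 5L)))        // 2 keys over a range of 2
        assert(!useDenseLayout(Seq(0L, 1000L)))    // 2 keys over a range of 1000
        println("ok")
      }
    }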
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 33d4976403..f015d29704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -22,6 +22,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -122,10 +123,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("broadcast hash join") {
- val N = 20 << 20
+ val N = 100 << 20
val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
- runBenchmark("BroadcastHashJoin", N) {
+ runBenchmark("Join w long", N) {
sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
}
@@ -133,9 +134,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- BroadcastHashJoin codegen=false 4405 / 6147 4.0 250.0 1.0X
- BroadcastHashJoin codegen=true 1857 / 1878 11.0 90.9 2.4X
+ Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X
+ Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X
+ */
+
+ val dim2 = broadcast(sqlContext.range(1 << 16)
+ .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
+
+ runBenchmark("Join w 2 ints", N) {
+ sqlContext.range(N).join(dim2,
+ (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
+ && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X
+ Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X
*/
+
}
ignore("hash and BytesToBytesMap") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e5fd9e277f..f985dfbd8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer
-
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
// Key is simply the record itself
@@ -134,4 +133,32 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
out2.flush()
assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
}
+
+ test("LongArrayRelation") {
+ val unsafeProj = UnsafeProjection.create(
+ Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
+ val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
+ val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
+ val longRelation = LongHashedRelation(rows.iterator, SQLMetrics.nullLongMetric, keyProj, 100)
+ assert(longRelation.isInstanceOf[LongArrayRelation])
+ val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
+ (0 until 100).foreach { i =>
+ val row = longArrayRelation.getValue(i)
+ assert(row.getInt(0) === i)
+ assert(row.getInt(1) === i + 1)
+ }
+
+ val os = new ByteArrayOutputStream()
+ val out = new ObjectOutputStream(os)
+ longArrayRelation.writeExternal(out)
+ out.flush()
+ val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+ val relation = new LongArrayRelation()
+ relation.readExternal(in)
+ (0 until 100).foreach { i =>
+ val row = relation.getValue(i)
+ assert(row.getInt(0) === i)
+ assert(row.getInt(1) === i + 1)
+ }
+ }
}