aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-08 14:09:14 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-08 14:09:14 -0800
commitff0af0ddfa4d198b203c3a39f8532cfbd4f4e027 (patch)
treebed882aeeb85eeb67562b1d2c58390d257896bca
parent37bc203c8dd5022cb11d53b697c28a737ee85bcc (diff)
downloadspark-ff0af0ddfa4d198b203c3a39f8532cfbd4f4e027.tar.gz
spark-ff0af0ddfa4d198b203c3a39f8532cfbd4f4e027.tar.bz2
spark-ff0af0ddfa4d198b203c3a39f8532cfbd4f4e027.zip
[SPARK-13095] [SQL] improve performance for broadcast join with dimension table
This PR improve the performance for Broadcast join with dimension tables, which is common in data warehouse. If the join key can fit in a long, we will use a special api `get(Long)` to get the rows from HashedRelation. If the HashedRelation only have unique keys, we will use a special api `getValue(Long)` or `getValue(InternalRow)`. If the keys can fit within a long, also the keys are dense, we will use a array of UnsafeRow, instead a hash map. TODO: will do cleanup Author: Davies Liu <davies@databricks.com> Closes #11065 from davies/gen_dim.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala96
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala44
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala297
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala27
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala29
8 files changed, 438 insertions, 69 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 131efea20f..4ca2d85406 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -38,6 +38,7 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
+ case _: BroadcastHashJoin => "bhj"
case _ => nodeName.toLowerCase
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 943ad31c0c..cbd549763a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -90,8 +90,14 @@ case class BroadcastHashJoin(
// The following line doesn't run in a job so we cannot track the metric value. However, we
// have already tracked it in the above lines. So here we can use
// `SQLMetrics.nullLongMetric` to ignore it.
- val hashed = HashedRelation(
- input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+ // TODO: move this check into HashedRelation
+ val hashed = if (canJoinKeyFitWithinLong) {
+ LongHashedRelation(
+ input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+ } else {
+ HashedRelation(
+ input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+ }
sparkContext.broadcast(hashed)
}
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
@@ -112,15 +118,12 @@ case class BroadcastHashJoin(
streamedPlan.execute().mapPartitions { streamedIter =>
val hashedRelation = broadcastRelation.value
- hashedRelation match {
- case unsafe: UnsafeHashedRelation =>
- TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
- case _ =>
- }
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}
+ private var broadcastRelation: Broadcast[HashedRelation] = _
// the term for hash relation
private var relationTerm: String = _
@@ -129,16 +132,15 @@ case class BroadcastHashJoin(
}
override def doProduce(ctx: CodegenContext): String = {
- // create a name for HashRelation
- val broadcastRelation = Await.result(broadcastFuture, timeout)
+ // create a name for HashedRelation
+ broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
relationTerm = ctx.freshName("relation")
- // TODO: create specialized HashRelation for single join key
- val clsName = classOf[UnsafeHashedRelation].getName
+ val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = ($clsName) $broadcast.value();
- | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+ | incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)
s"""
@@ -147,23 +149,24 @@ case class BroadcastHashJoin(
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- // generate the key as UnsafeRow
+ // generate the key as UnsafeRow or Long
ctx.currentVars = input
- val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
- val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
- val keyTerm = keyVal.value
- val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+ val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+ val expr = rewriteKeyExpr(streamedKeys).head
+ val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
+ (ev, ev.isNull)
+ } else {
+ val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+ val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+ (ev, s"${ev.value}.anyNull()")
+ }
// find the matches from HashedRelation
- val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
- val row = ctx.freshName("row")
+ val matched = ctx.freshName("matched")
// create variables for output
ctx.currentVars = null
- ctx.INPUT_ROW = row
+ ctx.INPUT_ROW = matched
val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).gen(ctx)
}
@@ -172,7 +175,7 @@ case class BroadcastHashJoin(
case BuildRight => input ++ buildColumns
}
- val ouputCode = if (condition.isDefined) {
+ val outputCode = if (condition.isDefined) {
// filter the output via condition
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
@@ -186,20 +189,39 @@ case class BroadcastHashJoin(
consume(ctx, resultVars)
}
- s"""
- | // generate join key
- | ${keyVal.code}
- | // find matches from HashRelation
- | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
- | if ($matches != null) {
- | int $size = $matches.size();
- | for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $row = (UnsafeRow) $matches.apply($i);
- | ${buildColumns.map(_.code).mkString("\n")}
- | $ouputCode
- | }
- | }
+ if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ s"""
+ | // generate join key
+ | ${keyVal.code}
+ | // find matches from HashedRelation
+ | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
+ | if ($matched != null) {
+ | ${buildColumns.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
""".stripMargin
+
+ } else {
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ s"""
+ | // generate join key
+ | ${keyVal.code}
+ | // find matches from HashRelation
+ | $bufferType $matches = ${anyNull} ? null :
+ | ($bufferType) $relationTerm.get(${keyVal.value});
+ | if ($matches != null) {
+ | int $size = $matches.size();
+ | for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ | ${buildColumns.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
+ | }
+ """.stripMargin
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index f48fc3b848..ad3275696e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -116,12 +116,7 @@ case class BroadcastHashOuterJoin(
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
val keyGenerator = streamedKeyGenerator
-
- hashTable match {
- case unsafe: UnsafeHashedRelation =>
- TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
- case _ =>
- }
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
val resultProj = resultProjection
joinType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 8929dc3af1..d0e18dfcf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -64,11 +64,7 @@ case class BroadcastLeftSemiJoinHash(
left.execute().mapPartitionsInternal { streamIter =>
val hashedRelation = broadcastedRelation.value
- hashedRelation match {
- case unsafe: UnsafeHashedRelation =>
- TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
- case _ =>
- }
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8ef854001f..ecbb1ac64b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.LongSQLMetric
-
+import org.apache.spark.sql.types.{IntegralType, LongType}
trait HashJoin {
self: SparkPlan =>
@@ -47,11 +47,49 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
+ /**
+ * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
+ *
+ * If not, returns the original expressions.
+ */
+ def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+ var keyExpr: Expression = null
+ var width = 0
+ keys.foreach { e =>
+ e.dataType match {
+ case dt: IntegralType if dt.defaultSize <= 8 - width =>
+ if (width == 0) {
+ if (e.dataType != LongType) {
+ keyExpr = Cast(e, LongType)
+ } else {
+ keyExpr = e
+ }
+ width = dt.defaultSize
+ } else {
+ val bits = dt.defaultSize * 8
+ keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+ BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+ width -= bits
+ }
+ // TODO: support BooleanType, DateType and TimestampType
+ case other =>
+ return keys
+ }
+ }
+ keyExpr :: Nil
+ }
+
+ protected val canJoinKeyFitWithinLong: Boolean = {
+ val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
+ val key = rewriteKeyExpr(buildKeys)
+ sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
+ }
+
protected def buildSideKeyGenerator: Projection =
- UnsafeProjection.create(buildKeys, buildPlan.output)
+ UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
protected def streamSideKeyGenerator: Projection =
- UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ee7a1bdc34..c94d6c195b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -39,8 +39,23 @@ import org.apache.spark.util.collection.CompactBuffer
* object.
*/
private[execution] sealed trait HashedRelation {
+ /**
+ * Returns matched rows.
+ */
def get(key: InternalRow): Seq[InternalRow]
+ /**
+ * Returns matched rows for a key that has only one column with LongType.
+ */
+ def get(key: Long): Seq[InternalRow] = {
+ throw new UnsupportedOperationException
+ }
+
+ /**
+ * Returns the size of used memory.
+ */
+ def getMemorySize: Long = 1L // to make the test happy
+
// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -58,11 +73,48 @@ private[execution] sealed trait HashedRelation {
}
}
+/**
+ * Interface for a hashed relation that have only one row per key.
+ *
+ * We should call getValue() for better performance.
+ */
+private[execution] trait UniqueHashedRelation extends HashedRelation {
+
+ /**
+ * Returns the matched single row.
+ */
+ def getValue(key: InternalRow): InternalRow
+
+ /**
+ * Returns the matched single row with key that have only one column of LongType.
+ */
+ def getValue(key: Long): InternalRow = {
+ throw new UnsupportedOperationException
+ }
+
+ override def get(key: InternalRow): Seq[InternalRow] = {
+ val row = getValue(key)
+ if (row != null) {
+ CompactBuffer[InternalRow](row)
+ } else {
+ null
+ }
+ }
+
+ override def get(key: Long): Seq[InternalRow] = {
+ val row = getValue(key)
+ if (row != null) {
+ CompactBuffer[InternalRow](row)
+ } else {
+ null
+ }
+ }
+}
/**
* A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
*/
-private[joins] final class GeneralHashedRelation(
+private[joins] class GeneralHashedRelation(
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {
@@ -85,19 +137,14 @@ private[joins] final class GeneralHashedRelation(
* A specialized [[HashedRelation]] that maps key into a single value. This implementation
* assumes the key is unique.
*/
-private[joins]
-final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
- extends HashedRelation with Externalizable {
+private[joins] class UniqueKeyHashedRelation(
+ private var hashTable: JavaHashMap[InternalRow, InternalRow])
+ extends UniqueHashedRelation with Externalizable {
// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)
- override def get(key: InternalRow): Seq[InternalRow] = {
- val v = hashTable.get(key)
- if (v eq null) null else CompactBuffer(v)
- }
-
- def getValue(key: InternalRow): InternalRow = hashTable.get(key)
+ override def getValue(key: InternalRow): InternalRow = hashTable.get(key)
override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -108,8 +155,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
}
}
-// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
-
private[execution] object HashedRelation {
@@ -208,7 +253,7 @@ private[joins] final class UnsafeHashedRelation(
*
* For non-broadcast joins or in local mode, return 0.
*/
- def getUnsafeSize: Long = {
+ override def getMemorySize: Long = {
if (binaryMap != null) {
binaryMap.getTotalMemoryConsumption
} else {
@@ -408,6 +453,232 @@ private[joins] object UnsafeHashedRelation {
}
}
+ // TODO: create UniqueUnsafeRelation
new UnsafeHashedRelation(hashTable)
}
}
+
+/**
+ * An interface for a hashed relation that the key is a Long.
+ */
+private[joins] trait LongHashedRelation extends HashedRelation {
+ override def get(key: InternalRow): Seq[InternalRow] = {
+ get(key.getLong(0))
+ }
+}
+
+private[joins] final class GeneralLongHashedRelation(
+ private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
+ extends LongHashedRelation with Externalizable {
+
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(null)
+
+ override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ }
+}
+
+private[joins] final class UniqueLongHashedRelation(
+ private var hashTable: JavaHashMap[Long, UnsafeRow])
+ extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(null)
+
+ override def getValue(key: InternalRow): InternalRow = {
+ getValue(key.getLong(0))
+ }
+
+ override def getValue(key: Long): InternalRow = {
+ hashTable.get(key)
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ }
+}
+
+/**
+ * A relation that pack all the rows into a byte array, together with offsets and sizes.
+ *
+ * All the bytes of UnsafeRow are packed together as `bytes`:
+ *
+ * [ Row0 ][ Row1 ][] ... [ RowN ]
+ *
+ * With keys:
+ *
+ * start start+1 ... start+N
+ *
+ * `offsets` are offsets of UnsafeRows in the `bytes`
+ * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
+ *
+ * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as:
+ *
+ * start = 3
+ * offsets = [0, 0, 24]
+ * sizes = [24, 0, 32]
+ * bytes = [0 - 24][][24 - 56]
+ */
+private[joins] final class LongArrayRelation(
+ private var numFields: Int,
+ private var start: Long,
+ private var offsets: Array[Int],
+ private var sizes: Array[Int],
+ private var bytes: Array[Byte]
+ ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(0, 0L, null, null, null)
+
+ override def getValue(key: InternalRow): InternalRow = {
+ getValue(key.getLong(0))
+ }
+
+ override def getMemorySize: Long = {
+ offsets.length * 4 + sizes.length * 4 + bytes.length
+ }
+
+ override def getValue(key: Long): InternalRow = {
+ val idx = (key - start).toInt
+ if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
+ val result = new UnsafeRow(numFields)
+ result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+ result
+ } else {
+ null
+ }
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ out.writeInt(numFields)
+ out.writeLong(start)
+ out.writeInt(sizes.length)
+ var i = 0
+ while (i < sizes.length) {
+ out.writeInt(sizes(i))
+ i += 1
+ }
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ numFields = in.readInt()
+ start = in.readLong()
+ val length = in.readInt()
+ // read sizes of rows
+ sizes = new Array[Int](length)
+ offsets = new Array[Int](length)
+ var i = 0
+ var offset = 0
+ while (i < length) {
+ offsets(i) = offset
+ sizes(i) = in.readInt()
+ offset += sizes(i)
+ i += 1
+ }
+ // read all the bytes
+ val total = in.readInt()
+ assert(total == offset)
+ bytes = new Array[Byte](total)
+ in.readFully(bytes)
+ }
+}
+
+/**
+ * Create hashed relation with key that is long.
+ */
+private[joins] object LongHashedRelation {
+
+ val DENSE_FACTOR = 0.2
+
+ def apply(
+ input: Iterator[InternalRow],
+ numInputRows: LongSQLMetric,
+ keyGenerator: Projection,
+ sizeEstimate: Int): HashedRelation = {
+
+ // Use a Java hash table here because unsafe maps expect fixed size records
+ val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+
+ // Create a mapping of key -> rows
+ var numFields = 0
+ var keyIsUnique = true
+ var minKey = Long.MaxValue
+ var maxKey = Long.MinValue
+ while (input.hasNext) {
+ val unsafeRow = input.next().asInstanceOf[UnsafeRow]
+ numFields = unsafeRow.numFields()
+ numInputRows += 1
+ val rowKey = keyGenerator(unsafeRow)
+ if (!rowKey.anyNull) {
+ val key = rowKey.getLong(0)
+ minKey = math.min(minKey, key)
+ maxKey = math.max(maxKey, key)
+ val existingMatchList = hashTable.get(key)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[UnsafeRow]()
+ hashTable.put(key, newMatchList)
+ newMatchList
+ } else {
+ keyIsUnique = false
+ existingMatchList
+ }
+ matchList += unsafeRow.copy()
+ }
+ }
+
+ if (keyIsUnique) {
+ if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+ // The keys are dense enough, so use LongArrayRelation
+ val length = (maxKey - minKey).toInt + 1
+ val sizes = new Array[Int](length)
+ val offsets = new Array[Int](length)
+ var offset = 0
+ var i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ offsets(i) = offset
+ sizes(i) = rows(0).getSizeInBytes
+ offset += sizes(i)
+ }
+ i += 1
+ }
+ val bytes = new Array[Byte](offset)
+ i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
+ }
+ i += 1
+ }
+ new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
+
+ } else {
+ // all the keys are unique, one row per key.
+ val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
+ val iter = hashTable.entrySet().iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ uniqHashTable.put(entry.getKey, entry.getValue()(0))
+ }
+ new UniqueLongHashedRelation(uniqHashTable)
+ }
+ } else {
+ new GeneralLongHashedRelation(hashTable)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 33d4976403..f015d29704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -22,6 +22,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -122,10 +123,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("broadcast hash join") {
- val N = 20 << 20
+ val N = 100 << 20
val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
- runBenchmark("BroadcastHashJoin", N) {
+ runBenchmark("Join w long", N) {
sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
}
@@ -133,9 +134,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- BroadcastHashJoin codegen=false 4405 / 6147 4.0 250.0 1.0X
- BroadcastHashJoin codegen=true 1857 / 1878 11.0 90.9 2.4X
+ Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X
+ Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X
+ */
+
+ val dim2 = broadcast(sqlContext.range(1 << 16)
+ .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
+
+ runBenchmark("Join w 2 ints", N) {
+ sqlContext.range(N).join(dim2,
+ (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
+ && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X
+ Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X
*/
+
}
ignore("hash and BytesToBytesMap") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e5fd9e277f..f985dfbd8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer
-
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
// Key is simply the record itself
@@ -134,4 +133,32 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
out2.flush()
assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
}
+
+ test("LongArrayRelation") {
+ val unsafeProj = UnsafeProjection.create(
+ Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
+ val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
+ val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
+ val longRelation = LongHashedRelation(rows.iterator, SQLMetrics.nullLongMetric, keyProj, 100)
+ assert(longRelation.isInstanceOf[LongArrayRelation])
+ val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
+ (0 until 100).foreach { i =>
+ val row = longArrayRelation.getValue(i)
+ assert(row.getInt(0) === i)
+ assert(row.getInt(1) === i + 1)
+ }
+
+ val os = new ByteArrayOutputStream()
+ val out = new ObjectOutputStream(os)
+ longArrayRelation.writeExternal(out)
+ out.flush()
+ val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+ val relation = new LongArrayRelation()
+ relation.readExternal(in)
+ (0 until 100).foreach { i =>
+ val row = longArrayRelation.getValue(i)
+ assert(row.getInt(0) === i)
+ assert(row.getInt(1) === i + 1)
+ }
+ }
}