author     Davies Liu <davies@databricks.com>    2016-04-04 10:01:24 -0700
committer  Davies Liu <davies.liu@gmail.com>     2016-04-04 10:01:24 -0700
commit     745425332f41e2ae94649f9d1ad675243f36f743 (patch)
tree       78f29665e7d8dc7bb8cb9c7cfb4ec9ef5cce15c3 /sql
parent     0340b3d279de6be4903673bbf3e6a1a2653de6c0 (diff)
[SPARK-14137] [SQL] Cleanup hash join
## What changes were proposed in this pull request?

This PR does a few cleanups on HashedRelation and HashJoin:

1) Merge HashedRelation and UniqueHashedRelation together.
2) Return an iterator from HashedRelation, so we do not need to create many UnsafeRow objects.
3) Return a copy of HashedRelation for thread safety in BroadcastHashJoin, so we can re-use the UnsafeRow objects.
4) Clean up HashJoin, sharing most of the code between BroadcastHashJoin and ShuffledHashJoin.
5) Remove UniqueLongHashedRelation, which will be replaced by LongUnsafeMap (another PR).
6) Update the benchmark; before this patch, the selectivity of the joins was too high.

## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #12102 from davies/cleanup_hash.
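Items 2) and 3) amount to a small API change on HashedRelation, which the self-contained sketch below illustrates (the `ToyHashedRelation` name and `String` payload are stand-ins, not Spark classes): `get()` now returns an `Iterator`, or `null` when there is no match, `keyIsUnique` replaces the separate `UniqueHashedRelation` trait, and each task probes its own `asReadOnlyCopy()` of the broadcast value so mutable row buffers can be re-used safely.

```scala
// Illustrative sketch only; the real interface is in HashedRelation.scala below.
final class ToyHashedRelation(private val table: Map[Long, Seq[String]]) {

  // True iff every key maps to exactly one row (cf. keyIsUnique in the diff).
  def keyIsUnique: Boolean = table.values.forall(_.size == 1)

  // In Spark the copy owns its own mutable UnsafeRow buffer; here it simply
  // gives each task/thread an independent instance to probe.
  def asReadOnlyCopy(): ToyHashedRelation = new ToyHashedRelation(table)

  // Returns null (not an empty iterator) when there is no match,
  // mirroring HashedRelation.get in this patch.
  def get(key: Long): Iterator[String] =
    table.get(key).map(_.iterator).orNull
}

object ToyHashedRelationDemo {
  def main(args: Array[String]): Unit = {
    val broadcastValue = new ToyHashedRelation(Map(1L -> Seq("a", "b"), 2L -> Seq("c")))
    val hashed = broadcastValue.asReadOnlyCopy()   // one copy per task
    val matches = hashed.get(1L)
    if (matches != null) matches.foreach(println)  // a, b
    println(hashed.keyIsUnique)                    // false: key 1L maps to two rows
    println(hashed.get(42L) == null)               // true: no match returns null
  }
}
```

With this shape, BroadcastHashJoin only needs `broadcastRelation.value.asReadOnlyCopy()` followed by `join(streamedIter, hashed, numOutputRows)`, as the first hunk below shows.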
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala     |  78
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala              | 217
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala        | 264
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala      |  32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala  |  64
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala   |  14
6 files changed, 268 insertions, 401 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 41e566c27b..67ac9e94ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -68,37 +68,9 @@ case class BroadcastHashJoin(
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow()
- val hashTable = broadcastRelation.value
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
-
- joinType match {
- case Inner =>
- hashJoin(streamedIter, hashTable, numOutputRows)
-
- case LeftOuter =>
- streamedIter.flatMap { currentRow =>
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
- }
-
- case RightOuter =>
- streamedIter.flatMap { currentRow =>
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
- }
-
- case LeftSemi =>
- hashSemiJoin(streamedIter, hashTable, numOutputRows)
-
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashJoin should not take $x as the JoinType")
- }
+ val hashed = broadcastRelation.value.asReadOnlyCopy()
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize)
+ join(streamedIter, hashed, numOutputRows)
}
}
@@ -132,7 +104,7 @@ case class BroadcastHashJoin(
val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
- | $relationTerm = ($clsName) $broadcast.value();
+ | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
| incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)
(broadcastRelation, relationTerm)
@@ -217,7 +189,7 @@ case class BroadcastHashJoin(
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -232,18 +204,15 @@ case class BroadcastHashJoin(
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|if ($matches == null) continue;
- |int $size = $matches.size();
- |for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ |while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
| $checkCondition
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
@@ -287,7 +256,7 @@ case class BroadcastHashJoin(
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -306,22 +275,21 @@ case class BroadcastHashJoin(
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
val i = ctx.freshName("i")
- val size = ctx.freshName("size")
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
- |int $size = $matches != null ? $matches.size() : 0;
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|boolean $found = false;
|// the last iteration of this loop is to emit an empty row if there is no matched rows.
- |for (int $i = 0; $i <= $size; $i++) {
- | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
| ${checkCondition.trim}
- | if (!$conditionPassed || ($i == $size && $found)) continue;
+ | if (!$conditionPassed) continue;
| $found = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
@@ -356,7 +324,7 @@ case class BroadcastHashJoin(
""
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -369,23 +337,19 @@ case class BroadcastHashJoin(
""".stripMargin
} else {
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|if ($matches == null) continue;
- |int $size = $matches.size();
|boolean $found = false;
- |for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ |while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
| $checkCondition
| $found = true;
- | break;
|}
|if (!$found) continue;
|$numOutput.add(1);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index c298b7dee0..b7c0f3e7d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.joins
import java.util.NoSuchElementException
+import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.LongSQLMetric
import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
import org.apache.spark.util.collection.CompactBuffer
@@ -110,169 +111,113 @@ trait HashJoin {
sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
}
- protected def buildSideKeyGenerator: Projection =
+ protected def buildSideKeyGenerator(): Projection =
UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
- protected def streamSideKeyGenerator: Projection =
+ protected def streamSideKeyGenerator(): UnsafeProjection =
UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
} else {
(r: InternalRow) => true
}
- protected def createResultProjection: (InternalRow) => InternalRow =
- UnsafeProjection.create(self.schema)
-
- protected def hashJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- new Iterator[InternalRow] {
- private[this] var currentStreamedRow: InternalRow = _
- private[this] var currentHashMatches: Seq[InternalRow] = _
- private[this] var currentMatchPosition: Int = -1
-
- // Mutable per row objects.
- private[this] val joinRow = new JoinedRow
- private[this] val resultProjection = createResultProjection
-
- private[this] val joinKeys = streamSideKeyGenerator
-
- override final def hasNext: Boolean = {
- while (true) {
- // check if it's end of current matches
- if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
- currentHashMatches = null
- currentMatchPosition = -1
- }
-
- // find the next match
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- val key = joinKeys(currentStreamedRow)
- if (!key.anyNull) {
- currentHashMatches = hashedRelation.get(key)
- if (currentHashMatches != null) {
- currentMatchPosition = 0
- }
- }
- }
- if (currentHashMatches == null) {
- return false
- }
-
- // found some matches
- buildSide match {
- case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
- }
- if (boundCondition(joinRow)) {
- return true
- } else {
- currentMatchPosition += 1
- }
- }
- false // unreachable
- }
-
- override final def next(): InternalRow = {
- // next() could be called without calling hasNext()
- if (hasNext) {
- currentMatchPosition += 1
- numOutputRows += 1
- resultProjection(joinRow)
- } else {
- throw new NoSuchElementException
- }
- }
+ protected def createResultProjection(): (InternalRow) => InternalRow = {
+ if (joinType == LeftSemi) {
+ UnsafeProjection.create(output, output)
+ } else {
+ // Always put the stream side on left to simplify implementation
+ // both of left and right side could be null
+ UnsafeProjection.create(
+ output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
}
}
- @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
- @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
- @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-
- protected[this] def leftOuterIterator(
- key: InternalRow,
- joinedRow: JoinedRow,
- rightIter: Iterable[InternalRow],
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (rightIter != null) {
- rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
- }
- }
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
- } else {
- temp
- }
+ private def innerJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinRow = new JoinedRow
+ val joinKeys = streamSideKeyGenerator()
+ streamIter.flatMap { srow =>
+ joinRow.withLeft(srow)
+ val matches = hashedRelation.get(joinKeys(srow))
+ if (matches != null) {
+ matches.map(joinRow.withRight(_)).filter(boundCondition)
} else {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+ Seq.empty
}
}
- ret.iterator
}
- protected[this] def rightOuterIterator(
- key: InternalRow,
- leftIter: Iterable[InternalRow],
- joinedRow: JoinedRow,
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (leftIter != null) {
- leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
+ private def outerJoin(
+ streamedIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinedRow = new JoinedRow()
+ val keyGenerator = streamSideKeyGenerator()
+ val nullRow = new GenericInternalRow(buildPlan.output.length)
+
+ streamedIter.flatMap { currentRow =>
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withLeft(currentRow)
+ val buildIter = hashedRelation.get(rowKey)
+ new RowIterator {
+ private var found = false
+ override def advanceNext(): Boolean = {
+ while (buildIter != null && buildIter.hasNext) {
+ val nextBuildRow = buildIter.next()
+ if (boundCondition(joinedRow.withRight(nextBuildRow))) {
+ found = true
+ return true
}
}
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- } else {
- temp
+ if (!found) {
+ joinedRow.withRight(nullRow)
+ found = true
+ return true
+ }
+ false
}
- } else {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- }
+ override def getRow: InternalRow = joinedRow
+ }.toScala
}
- ret.iterator
}
- protected def hashSemiJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val joinKeys = streamSideKeyGenerator
+ private def semiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter { current =>
val key = joinKeys(current)
- lazy val rowBuffer = hashedRelation.get(key)
- val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+ lazy val buildIter = hashedRelation.get(key)
+ !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists {
(row: InternalRow) => boundCondition(joinedRow(current, row))
})
- if (r) numOutputRows += 1
- r
+ }
+ }
+
+ protected def join(
+ streamedIter: Iterator[InternalRow],
+ hashed: HashedRelation,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+
+ val joinedIter = joinType match {
+ case Inner =>
+ innerJoin(streamedIter, hashed)
+ case LeftOuter | RightOuter =>
+ outerJoin(streamedIter, hashed)
+ case LeftSemi =>
+ semiJoin(streamedIter, hashed)
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastHashJoin should not take $x as the JoinType")
+ }
+
+ val resultProj = createResultProjection
+ joinedIter.map { r =>
+ numOutputRows += 1
+ resultProj(r)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 91c470d187..5ccb435686 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.{KnownSizeEstimation, Utils}
import org.apache.spark.util.collection.CompactBuffer
@@ -39,17 +38,43 @@ import org.apache.spark.util.collection.CompactBuffer
private[execution] sealed trait HashedRelation {
/**
* Returns matched rows.
+ *
+ * Returns null if there is no matched rows.
*/
- def get(key: InternalRow): Seq[InternalRow]
+ def get(key: InternalRow): Iterator[InternalRow]
/**
* Returns matched rows for a key that has only one column with LongType.
+ *
+ * Returns null if there is no matched rows.
+ */
+ def get(key: Long): Iterator[InternalRow] = {
+ throw new UnsupportedOperationException
+ }
+
+ /**
+ * Returns the matched single row.
*/
- def get(key: Long): Seq[InternalRow] = {
+ def getValue(key: InternalRow): InternalRow
+
+ /**
+ * Returns the matched single row with key that have only one column of LongType.
+ */
+ def getValue(key: Long): InternalRow = {
throw new UnsupportedOperationException
}
/**
+ * Returns true iff all the keys are unique.
+ */
+ def keyIsUnique: Boolean
+
+ /**
+ * Returns a read-only copy of this, to be safely used in current thread.
+ */
+ def asReadOnlyCopy(): HashedRelation
+
+ /**
* Returns the size of used memory.
*/
def getMemorySize: Long = 1L // to make the test happy
@@ -76,44 +101,6 @@ private[execution] sealed trait HashedRelation {
}
}
-/**
- * Interface for a hashed relation that have only one row per key.
- *
- * We should call getValue() for better performance.
- */
-private[execution] trait UniqueHashedRelation extends HashedRelation {
-
- /**
- * Returns the matched single row.
- */
- def getValue(key: InternalRow): InternalRow
-
- /**
- * Returns the matched single row with key that have only one column of LongType.
- */
- def getValue(key: Long): InternalRow = {
- throw new UnsupportedOperationException
- }
-
- override def get(key: InternalRow): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
-
- override def get(key: Long): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
-}
-
private[execution] object HashedRelation {
/**
@@ -150,6 +137,11 @@ private[joins] class UnsafeHashedRelation(
private[joins] def this() = this(0, null) // Needed for serialization
+ override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
+
+ override def asReadOnlyCopy(): UnsafeHashedRelation =
+ new UnsafeHashedRelation(numFields, binaryMap)
+
override def getMemorySize: Long = {
binaryMap.getTotalMemoryConsumption
}
@@ -158,23 +150,39 @@ private[joins] class UnsafeHashedRelation(
binaryMap.getTotalMemoryConsumption
}
- override def get(key: InternalRow): Seq[InternalRow] = {
+ // re-used in get()/getValue()
+ var resultRow = new UnsafeRow(numFields)
+
+ override def get(key: InternalRow): Iterator[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
val map = binaryMap // avoid the compiler error
val loc = new map.Location // this could be allocated in stack
binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
if (loc.isDefined) {
- val buffer = CompactBuffer[UnsafeRow]()
- val row = new UnsafeRow(numFields)
- row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- buffer += row
- while (loc.nextValue()) {
- val row = new UnsafeRow(numFields)
- row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- buffer += row
+ new Iterator[UnsafeRow] {
+ private var _hasNext = true
+ override def hasNext: Boolean = _hasNext
+ override def next(): UnsafeRow = {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ _hasNext = loc.nextValue()
+ resultRow
+ }
}
- buffer
+ } else {
+ null
+ }
+ }
+
+ def getValue(key: InternalRow): InternalRow = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ resultRow
} else {
null
}
@@ -212,6 +220,7 @@ private[joins] class UnsafeHashedRelation(
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
numFields = in.readInt()
+ resultRow = new UnsafeRow(numFields)
val nKeys = in.readInt()
val nValues = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
@@ -263,29 +272,6 @@ private[joins] class UnsafeHashedRelation(
}
}
-/**
- * A HashedRelation for UnsafeRow with unique keys.
- */
-private[joins] final class UniqueUnsafeHashedRelation(
- private var numFields: Int,
- private var binaryMap: BytesToBytesMap)
- extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
- def getValue(key: InternalRow): InternalRow = {
- val unsafeKey = key.asInstanceOf[UnsafeRow]
- val map = binaryMap // avoid the compiler error
- val loc = new map.Location // this could be allocated in stack
- binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
- if (loc.isDefined) {
- val row = new UnsafeRow(numFields)
- row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- row
- } else {
- null
- }
- }
-}
-
private[joins] object UnsafeHashedRelation {
def apply(
@@ -315,17 +301,12 @@ private[joins] object UnsafeHashedRelation {
// Create a mapping of buildKeys -> rows
var numFields = 0
- // Whether all the keys are unique or not
- var allUnique: Boolean = true
while (input.hasNext) {
val row = input.next().asInstanceOf[UnsafeRow]
numFields = row.numFields()
val key = keyGenerator(row)
if (!key.anyNull) {
val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
- if (loc.isDefined) {
- allUnique = false
- }
val success = loc.append(
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
@@ -336,11 +317,7 @@ private[joins] object UnsafeHashedRelation {
}
}
- if (allUnique) {
- new UniqueUnsafeHashedRelation(numFields, binaryMap)
- } else {
- new UnsafeHashedRelation(numFields, binaryMap)
- }
+ new UnsafeHashedRelation(numFields, binaryMap)
}
}
@@ -348,9 +325,12 @@ private[joins] object UnsafeHashedRelation {
* An interface for a hashed relation that the key is a Long.
*/
private[joins] trait LongHashedRelation extends HashedRelation {
- override def get(key: InternalRow): Seq[InternalRow] = {
+ override def get(key: InternalRow): Iterator[InternalRow] = {
get(key.getLong(0))
}
+ override def getValue(key: InternalRow): InternalRow = {
+ getValue(key.getLong(0))
+ }
}
private[joins] final class GeneralLongHashedRelation(
@@ -360,30 +340,18 @@ private[joins] final class GeneralLongHashedRelation(
// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)
- override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+ override def keyIsUnique: Boolean = false
- override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
- }
-
- override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
- }
-}
+ override def asReadOnlyCopy(): GeneralLongHashedRelation =
+ new GeneralLongHashedRelation(hashTable)
-private[joins] final class UniqueLongHashedRelation(
- private var hashTable: JavaHashMap[Long, UnsafeRow])
- extends UniqueHashedRelation with LongHashedRelation with Externalizable {
-
- // Needed for serialization (it is public to make Java serialization work)
- def this() = this(null)
-
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
- }
-
- override def getValue(key: Long): InternalRow = {
- hashTable.get(key)
+ override def get(key: Long): Iterator[InternalRow] = {
+ val rows = hashTable.get(key)
+ if (rows != null) {
+ rows.toIterator
+ } else {
+ null
+ }
}
override def writeExternal(out: ObjectOutput): Unit = {
@@ -422,25 +390,36 @@ private[joins] final class LongArrayRelation(
private var offsets: Array[Int],
private var sizes: Array[Int],
private var bytes: Array[Byte]
- ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+ ) extends LongHashedRelation with Externalizable {
// Needed for serialization (it is public to make Java serialization work)
def this() = this(0, 0L, null, null, null)
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
+ override def keyIsUnique: Boolean = true
+
+ override def asReadOnlyCopy(): LongArrayRelation = {
+ new LongArrayRelation(numFields, start, offsets, sizes, bytes)
}
override def getMemorySize: Long = {
offsets.length * 4 + sizes.length * 4 + bytes.length
}
+ override def get(key: Long): Iterator[InternalRow] = {
+ val row = getValue(key)
+ if (row != null) {
+ Seq(row).toIterator
+ } else {
+ null
+ }
+ }
+
+ var resultRow = new UnsafeRow(numFields)
override def getValue(key: Long): InternalRow = {
val idx = (key - start).toInt
if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
- val result = new UnsafeRow(numFields)
- result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
- result
+ resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+ resultRow
} else {
null
}
@@ -461,6 +440,7 @@ private[joins] final class LongArrayRelation(
override def readExternal(in: ObjectInput): Unit = {
numFields = in.readInt()
+ resultRow = new UnsafeRow(numFields)
start = in.readLong()
val length = in.readInt()
// read sizes of rows
@@ -523,44 +503,32 @@ private[joins] object LongHashedRelation {
}
}
- if (keyIsUnique) {
- if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
- // The keys are dense enough, so use LongArrayRelation
- val length = (maxKey - minKey).toInt + 1
- val sizes = new Array[Int](length)
- val offsets = new Array[Int](length)
- var offset = 0
- var i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- offsets(i) = offset
- sizes(i) = rows(0).getSizeInBytes
- offset += sizes(i)
- }
- i += 1
- }
- val bytes = new Array[Byte](offset)
- i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
- }
- i += 1
+ if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+ // The keys are dense enough, so use LongArrayRelation
+ val length = (maxKey - minKey).toInt + 1
+ val sizes = new Array[Int](length)
+ val offsets = new Array[Int](length)
+ var offset = 0
+ var i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ offsets(i) = offset
+ sizes(i) = rows(0).getSizeInBytes
+ offset += sizes(i)
}
- new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
-
- } else {
- // all the keys are unique, one row per key.
- val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- uniqHashTable.put(entry.getKey, entry.getValue()(0))
+ i += 1
+ }
+ val bytes = new Array[Byte](offset)
+ i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
}
- new UniqueLongHashedRelation(uniqHashTable)
+ i += 1
}
+ new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
} else {
new GeneralLongHashedRelation(hashTable)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index e3a2eaea5d..c63faacf33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -102,39 +102,9 @@ case class ShuffledHashJoin(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
-
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
- val joinedRow = new JoinedRow
- joinType match {
- case Inner =>
- hashJoin(streamIter, hashed, numOutputRows)
-
- case LeftSemi =>
- hashSemiJoin(streamIter, hashed, numOutputRows)
-
- case LeftOuter =>
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
- streamIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
- })
-
- case RightOuter =>
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
- streamIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
- })
-
- case x =>
- throw new IllegalArgumentException(
- s"ShuffledHashJoin should not take $x as the JoinType")
- }
+ join(streamIter, hashed, numOutputRows)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 289e1b6db9..3566ef3043 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -166,20 +166,35 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("broadcast hash join") {
- val N = 100 << 20
+ val N = 20 << 20
val M = 1 << 16
val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("Join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
}
/*
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X
- Join w long codegen=true 735 / 853 142.7 7.0 7.8X
+ Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X
+ Join w long codegen=true 275 / 352 76.2 13.1 19.4X
+ */
+
+ runBenchmark("Join w long duplicated", N) {
+ val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k"))
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
+ }
+
+ /**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X
+ Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X
*/
val dim2 = broadcast(sqlContext.range(M)
@@ -187,16 +202,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
runBenchmark("Join w 2 ints", N) {
sqlContext.range(N).join(dim2,
- (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
- && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
+ (col("id") % M).cast(IntegerType) === col("k1")
+ && (col("id") % M).cast(IntegerType) === col("k2")).count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X
- Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X
+ Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X
+ Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X
*/
val dim3 = broadcast(sqlContext.range(M)
@@ -204,16 +220,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
runBenchmark("Join w 2 longs", N) {
sqlContext.range(N).join(dim3,
- (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+ (col("id") % M) === col("k1") && (col("id") % M) === col("k2"))
.count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 longs codegen=false 12725 / 13158 8.2 121.4 1.0X
- Join w 2 longs codegen=true 6044 / 6771 17.3 57.6 2.1X
+ Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X
+ Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X
*/
val dim4 = broadcast(sqlContext.range(M)
@@ -227,34 +244,36 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 longs duplicated codegen=false 13066 / 13710 8.0 124.6 1.0X
- Join w 2 longs duplicated codegen=true 7122 / 7277 14.7 67.9 1.8X
+ Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X
+ Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X
*/
runBenchmark("outer join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X
- outer join w long codegen=true 769 / 796 136.3 7.3 19.9X
+ outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X
+ outer join w long codegen=true 216 / 226 97.2 10.3 26.3X
*/
runBenchmark("semi join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- semi join w long codegen=false 5804 / 5969 18.1 55.3 1.0X
- semi join w long codegen=true 814 / 934 128.8 7.8 7.1X
+ semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X
+ semi join w long codegen=true 211 / 229 99.2 10.1 22.2X
*/
}
@@ -303,11 +322,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X
- shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X
+ shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X
+ shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X
*/
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index ed4cc1c4c4..ed87a99439 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -34,20 +34,20 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val toUnsafe = UnsafeProjection.create(schema)
- val unsafeData = data.map(toUnsafe(_).copy()).toArray
+ val unsafeData = data.map(toUnsafe(_).copy())
val buildKey = Seq(BoundReference(0, IntegerType, false))
val keyGenerator = UnsafeProjection.create(buildKey)
val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
- assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
- assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+ assert(hashed.get(unsafeData(1)).toArray === Array(unsafeData(1)))
assert(hashed.get(toUnsafe(InternalRow(10))) === null)
val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
data2 += unsafeData(2).copy()
- assert(hashed.get(unsafeData(2)) === data2)
+ assert(hashed.get(unsafeData(2)).toArray === data2.toArray)
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
@@ -56,10 +56,10 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val hashed2 = new UnsafeHashedRelation()
hashed2.readExternal(in)
- assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
- assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed2.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+ assert(hashed2.get(unsafeData(1)).toArray === Array(unsafeData(1)))
assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
- assert(hashed2.get(unsafeData(2)) === data2)
+ assert(hashed2.get(unsafeData(2)).toArray === data2)
val os2 = new ByteArrayOutputStream()
val out2 = new ObjectOutputStream(os2)