Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala |  78
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala          | 217
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala    | 264
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala  |  32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala |  64
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala  |  14
6 files changed, 268 insertions, 401 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 41e566c27b..67ac9e94ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -68,37 +68,9 @@ case class BroadcastHashJoin(
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow()
- val hashTable = broadcastRelation.value
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
-
- joinType match {
- case Inner =>
- hashJoin(streamedIter, hashTable, numOutputRows)
-
- case LeftOuter =>
- streamedIter.flatMap { currentRow =>
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
- }
-
- case RightOuter =>
- streamedIter.flatMap { currentRow =>
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
- }
-
- case LeftSemi =>
- hashSemiJoin(streamedIter, hashTable, numOutputRows)
-
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashJoin should not take $x as the JoinType")
- }
+ val hashed = broadcastRelation.value.asReadOnlyCopy()
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize)
+ join(streamedIter, hashed, numOutputRows)
}
}
@@ -132,7 +104,7 @@ case class BroadcastHashJoin(
val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
- | $relationTerm = ($clsName) $broadcast.value();
+ | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
| incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)
(broadcastRelation, relationTerm)
@@ -217,7 +189,7 @@ case class BroadcastHashJoin(
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -232,18 +204,15 @@ case class BroadcastHashJoin(
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|if ($matches == null) continue;
- |int $size = $matches.size();
- |for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ |while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
| $checkCondition
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
@@ -287,7 +256,7 @@ case class BroadcastHashJoin(
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -306,22 +275,21 @@ case class BroadcastHashJoin(
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
val i = ctx.freshName("i")
- val size = ctx.freshName("size")
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
- |int $size = $matches != null ? $matches.size() : 0;
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|boolean $found = false;
|// the last iteration of this loop is to emit an empty row if there is no matched rows.
- |for (int $i = 0; $i <= $size; $i++) {
- | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
| ${checkCondition.trim}
- | if (!$conditionPassed || ($i == $size && $found)) continue;
+ | if (!$conditionPassed) continue;
| $found = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
@@ -356,7 +324,7 @@ case class BroadcastHashJoin(
""
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -369,23 +337,19 @@ case class BroadcastHashJoin(
""".stripMargin
} else {
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|if ($matches == null) continue;
- |int $size = $matches.size();
|boolean $found = false;
- |for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ |while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
| $checkCondition
| $found = true;
- | break;
|}
|if (!$found) continue;
|$numOutput.add(1);
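
The generated loops above now probe the relation through HashedRelation.get(), which returns an Iterator[UnsafeRow] (or null on a miss) instead of a CompactBuffer, and branch on keyIsUnique rather than on a UniqueHashedRelation subclass. A minimal Scala sketch of the two probe shapes this implies (not the generated Java; probeOne and all of its parameters are illustrative names, and it assumes code inside org.apache.spark.sql.execution, where the private[execution] HashedRelation trait is visible):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}

    def probeOne(
        streamRow: InternalRow,
        relation: HashedRelation,
        joinKeys: UnsafeProjection,
        joinedRow: JoinedRow,
        boundCondition: InternalRow => Boolean): Iterator[InternalRow] = {
      val key = joinKeys(streamRow)
      if (key.anyNull) return Iterator.empty
      joinedRow.withLeft(streamRow)
      if (relation.keyIsUnique) {
        // unique keys: a single getValue() lookup, no iterator allocation
        val matched = relation.getValue(key)
        if (matched != null && boundCondition(joinedRow.withRight(matched))) {
          Iterator.single(joinedRow)
        } else {
          Iterator.empty
        }
      } else {
        // non-unique keys: get() returns null on a miss, otherwise an Iterator of matches
        val matches = relation.get(key)
        if (matches == null) Iterator.empty
        else matches.map(joinedRow.withRight).filter(boundCondition)
      }
    }

Branching on keyIsUnique keeps a single relation class per map layout while still letting the unique-key case skip the iterator entirely.
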
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index c298b7dee0..b7c0f3e7d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.joins
import java.util.NoSuchElementException
+import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.LongSQLMetric
import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
import org.apache.spark.util.collection.CompactBuffer
@@ -110,169 +111,113 @@ trait HashJoin {
sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
}
- protected def buildSideKeyGenerator: Projection =
+ protected def buildSideKeyGenerator(): Projection =
UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
- protected def streamSideKeyGenerator: Projection =
+ protected def streamSideKeyGenerator(): UnsafeProjection =
UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
} else {
(r: InternalRow) => true
}
- protected def createResultProjection: (InternalRow) => InternalRow =
- UnsafeProjection.create(self.schema)
-
- protected def hashJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- new Iterator[InternalRow] {
- private[this] var currentStreamedRow: InternalRow = _
- private[this] var currentHashMatches: Seq[InternalRow] = _
- private[this] var currentMatchPosition: Int = -1
-
- // Mutable per row objects.
- private[this] val joinRow = new JoinedRow
- private[this] val resultProjection = createResultProjection
-
- private[this] val joinKeys = streamSideKeyGenerator
-
- override final def hasNext: Boolean = {
- while (true) {
- // check if it's end of current matches
- if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
- currentHashMatches = null
- currentMatchPosition = -1
- }
-
- // find the next match
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- val key = joinKeys(currentStreamedRow)
- if (!key.anyNull) {
- currentHashMatches = hashedRelation.get(key)
- if (currentHashMatches != null) {
- currentMatchPosition = 0
- }
- }
- }
- if (currentHashMatches == null) {
- return false
- }
-
- // found some matches
- buildSide match {
- case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
- }
- if (boundCondition(joinRow)) {
- return true
- } else {
- currentMatchPosition += 1
- }
- }
- false // unreachable
- }
-
- override final def next(): InternalRow = {
- // next() could be called without calling hasNext()
- if (hasNext) {
- currentMatchPosition += 1
- numOutputRows += 1
- resultProjection(joinRow)
- } else {
- throw new NoSuchElementException
- }
- }
+ protected def createResultProjection(): (InternalRow) => InternalRow = {
+ if (joinType == LeftSemi) {
+ UnsafeProjection.create(output, output)
+ } else {
+ // Always put the stream side on left to simplify implementation
+ // both of left and right side could be null
+ UnsafeProjection.create(
+ output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
}
}
- @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
- @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
- @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-
- protected[this] def leftOuterIterator(
- key: InternalRow,
- joinedRow: JoinedRow,
- rightIter: Iterable[InternalRow],
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (rightIter != null) {
- rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
- }
- }
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
- } else {
- temp
- }
+ private def innerJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinRow = new JoinedRow
+ val joinKeys = streamSideKeyGenerator()
+ streamIter.flatMap { srow =>
+ joinRow.withLeft(srow)
+ val matches = hashedRelation.get(joinKeys(srow))
+ if (matches != null) {
+ matches.map(joinRow.withRight(_)).filter(boundCondition)
} else {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+ Seq.empty
}
}
- ret.iterator
}
- protected[this] def rightOuterIterator(
- key: InternalRow,
- leftIter: Iterable[InternalRow],
- joinedRow: JoinedRow,
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (leftIter != null) {
- leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
+ private def outerJoin(
+ streamedIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinedRow = new JoinedRow()
+ val keyGenerator = streamSideKeyGenerator()
+ val nullRow = new GenericInternalRow(buildPlan.output.length)
+
+ streamedIter.flatMap { currentRow =>
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withLeft(currentRow)
+ val buildIter = hashedRelation.get(rowKey)
+ new RowIterator {
+ private var found = false
+ override def advanceNext(): Boolean = {
+ while (buildIter != null && buildIter.hasNext) {
+ val nextBuildRow = buildIter.next()
+ if (boundCondition(joinedRow.withRight(nextBuildRow))) {
+ found = true
+ return true
}
}
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- } else {
- temp
+ if (!found) {
+ joinedRow.withRight(nullRow)
+ found = true
+ return true
+ }
+ false
}
- } else {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- }
+ override def getRow: InternalRow = joinedRow
+ }.toScala
}
- ret.iterator
}
- protected def hashSemiJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val joinKeys = streamSideKeyGenerator
+ private def semiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter { current =>
val key = joinKeys(current)
- lazy val rowBuffer = hashedRelation.get(key)
- val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+ lazy val buildIter = hashedRelation.get(key)
+ !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists {
(row: InternalRow) => boundCondition(joinedRow(current, row))
})
- if (r) numOutputRows += 1
- r
+ }
+ }
+
+ protected def join(
+ streamedIter: Iterator[InternalRow],
+ hashed: HashedRelation,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+
+ val joinedIter = joinType match {
+ case Inner =>
+ innerJoin(streamedIter, hashed)
+ case LeftOuter | RightOuter =>
+ outerJoin(streamedIter, hashed)
+ case LeftSemi =>
+ semiJoin(streamedIter, hashed)
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastHashJoin should not take $x as the JoinType")
+ }
+
+ val resultProj = createResultProjection
+ joinedIter.map { r =>
+ numOutputRows += 1
+ resultProj(r)
}
}
}
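
The new outerJoin above drives the probe with a small RowIterator state machine and adapts it back to a Scala iterator via toScala, so a stream row with no qualifying match still emits exactly one null-padded row. A minimal sketch of that contract (assuming code that can see Spark's internal org.apache.spark.sql.execution.RowIterator; SeqRowIterator is an illustrative name, not part of the patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.RowIterator

    // advanceNext() positions the next row and returns false once exhausted;
    // getRow must return the row positioned by the last successful advanceNext().
    class SeqRowIterator(rows: Seq[InternalRow]) extends RowIterator {
      private[this] val underlying = rows.iterator
      private[this] var current: InternalRow = _
      override def advanceNext(): Boolean = {
        if (underlying.hasNext) { current = underlying.next(); true } else false
      }
      override def getRow: InternalRow = current
    }

    // new SeqRowIterator(rows).toScala behaves like rows.iterator
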
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 91c470d187..5ccb435686 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.{KnownSizeEstimation, Utils}
import org.apache.spark.util.collection.CompactBuffer
@@ -39,17 +38,43 @@ import org.apache.spark.util.collection.CompactBuffer
private[execution] sealed trait HashedRelation {
/**
* Returns matched rows.
+ *
+ * Returns null if there is no matched rows.
*/
- def get(key: InternalRow): Seq[InternalRow]
+ def get(key: InternalRow): Iterator[InternalRow]
/**
* Returns matched rows for a key that has only one column with LongType.
+ *
+ * Returns null if there is no matched rows.
+ */
+ def get(key: Long): Iterator[InternalRow] = {
+ throw new UnsupportedOperationException
+ }
+
+ /**
+ * Returns the matched single row.
*/
- def get(key: Long): Seq[InternalRow] = {
+ def getValue(key: InternalRow): InternalRow
+
+ /**
+ * Returns the matched single row with key that have only one column of LongType.
+ */
+ def getValue(key: Long): InternalRow = {
throw new UnsupportedOperationException
}
/**
+ * Returns true iff all the keys are unique.
+ */
+ def keyIsUnique: Boolean
+
+ /**
+ * Returns a read-only copy of this, to be safely used in current thread.
+ */
+ def asReadOnlyCopy(): HashedRelation
+
+ /**
* Returns the size of used memory.
*/
def getMemorySize: Long = 1L // to make the test happy
@@ -76,44 +101,6 @@ private[execution] sealed trait HashedRelation {
}
}
-/**
- * Interface for a hashed relation that have only one row per key.
- *
- * We should call getValue() for better performance.
- */
-private[execution] trait UniqueHashedRelation extends HashedRelation {
-
- /**
- * Returns the matched single row.
- */
- def getValue(key: InternalRow): InternalRow
-
- /**
- * Returns the matched single row with key that have only one column of LongType.
- */
- def getValue(key: Long): InternalRow = {
- throw new UnsupportedOperationException
- }
-
- override def get(key: InternalRow): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
-
- override def get(key: Long): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
-}
-
private[execution] object HashedRelation {
/**
@@ -150,6 +137,11 @@ private[joins] class UnsafeHashedRelation(
private[joins] def this() = this(0, null) // Needed for serialization
+ override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
+
+ override def asReadOnlyCopy(): UnsafeHashedRelation =
+ new UnsafeHashedRelation(numFields, binaryMap)
+
override def getMemorySize: Long = {
binaryMap.getTotalMemoryConsumption
}
@@ -158,23 +150,39 @@ private[joins] class UnsafeHashedRelation(
binaryMap.getTotalMemoryConsumption
}
- override def get(key: InternalRow): Seq[InternalRow] = {
+ // re-used in get()/getValue()
+ var resultRow = new UnsafeRow(numFields)
+
+ override def get(key: InternalRow): Iterator[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
val map = binaryMap // avoid the compiler error
val loc = new map.Location // this could be allocated in stack
binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
if (loc.isDefined) {
- val buffer = CompactBuffer[UnsafeRow]()
- val row = new UnsafeRow(numFields)
- row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- buffer += row
- while (loc.nextValue()) {
- val row = new UnsafeRow(numFields)
- row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- buffer += row
+ new Iterator[UnsafeRow] {
+ private var _hasNext = true
+ override def hasNext: Boolean = _hasNext
+ override def next(): UnsafeRow = {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ _hasNext = loc.nextValue()
+ resultRow
+ }
}
- buffer
+ } else {
+ null
+ }
+ }
+
+ def getValue(key: InternalRow): InternalRow = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ resultRow
} else {
null
}
@@ -212,6 +220,7 @@ private[joins] class UnsafeHashedRelation(
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
numFields = in.readInt()
+ resultRow = new UnsafeRow(numFields)
val nKeys = in.readInt()
val nValues = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
@@ -263,29 +272,6 @@ private[joins] class UnsafeHashedRelation(
}
}
-/**
- * A HashedRelation for UnsafeRow with unique keys.
- */
-private[joins] final class UniqueUnsafeHashedRelation(
- private var numFields: Int,
- private var binaryMap: BytesToBytesMap)
- extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
- def getValue(key: InternalRow): InternalRow = {
- val unsafeKey = key.asInstanceOf[UnsafeRow]
- val map = binaryMap // avoid the compiler error
- val loc = new map.Location // this could be allocated in stack
- binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
- if (loc.isDefined) {
- val row = new UnsafeRow(numFields)
- row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- row
- } else {
- null
- }
- }
-}
-
private[joins] object UnsafeHashedRelation {
def apply(
@@ -315,17 +301,12 @@ private[joins] object UnsafeHashedRelation {
// Create a mapping of buildKeys -> rows
var numFields = 0
- // Whether all the keys are unique or not
- var allUnique: Boolean = true
while (input.hasNext) {
val row = input.next().asInstanceOf[UnsafeRow]
numFields = row.numFields()
val key = keyGenerator(row)
if (!key.anyNull) {
val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
- if (loc.isDefined) {
- allUnique = false
- }
val success = loc.append(
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
@@ -336,11 +317,7 @@ private[joins] object UnsafeHashedRelation {
}
}
- if (allUnique) {
- new UniqueUnsafeHashedRelation(numFields, binaryMap)
- } else {
- new UnsafeHashedRelation(numFields, binaryMap)
- }
+ new UnsafeHashedRelation(numFields, binaryMap)
}
}
@@ -348,9 +325,12 @@ private[joins] object UnsafeHashedRelation {
* An interface for a hashed relation that the key is a Long.
*/
private[joins] trait LongHashedRelation extends HashedRelation {
- override def get(key: InternalRow): Seq[InternalRow] = {
+ override def get(key: InternalRow): Iterator[InternalRow] = {
get(key.getLong(0))
}
+ override def getValue(key: InternalRow): InternalRow = {
+ getValue(key.getLong(0))
+ }
}
private[joins] final class GeneralLongHashedRelation(
@@ -360,30 +340,18 @@ private[joins] final class GeneralLongHashedRelation(
// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)
- override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+ override def keyIsUnique: Boolean = false
- override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
- }
-
- override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
- }
-}
+ override def asReadOnlyCopy(): GeneralLongHashedRelation =
+ new GeneralLongHashedRelation(hashTable)
-private[joins] final class UniqueLongHashedRelation(
- private var hashTable: JavaHashMap[Long, UnsafeRow])
- extends UniqueHashedRelation with LongHashedRelation with Externalizable {
-
- // Needed for serialization (it is public to make Java serialization work)
- def this() = this(null)
-
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
- }
-
- override def getValue(key: Long): InternalRow = {
- hashTable.get(key)
+ override def get(key: Long): Iterator[InternalRow] = {
+ val rows = hashTable.get(key)
+ if (rows != null) {
+ rows.toIterator
+ } else {
+ null
+ }
}
override def writeExternal(out: ObjectOutput): Unit = {
@@ -422,25 +390,36 @@ private[joins] final class LongArrayRelation(
private var offsets: Array[Int],
private var sizes: Array[Int],
private var bytes: Array[Byte]
- ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+ ) extends LongHashedRelation with Externalizable {
// Needed for serialization (it is public to make Java serialization work)
def this() = this(0, 0L, null, null, null)
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
+ override def keyIsUnique: Boolean = true
+
+ override def asReadOnlyCopy(): LongArrayRelation = {
+ new LongArrayRelation(numFields, start, offsets, sizes, bytes)
}
override def getMemorySize: Long = {
offsets.length * 4 + sizes.length * 4 + bytes.length
}
+ override def get(key: Long): Iterator[InternalRow] = {
+ val row = getValue(key)
+ if (row != null) {
+ Seq(row).toIterator
+ } else {
+ null
+ }
+ }
+
+ var resultRow = new UnsafeRow(numFields)
override def getValue(key: Long): InternalRow = {
val idx = (key - start).toInt
if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
- val result = new UnsafeRow(numFields)
- result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
- result
+ resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+ resultRow
} else {
null
}
@@ -461,6 +440,7 @@ private[joins] final class LongArrayRelation(
override def readExternal(in: ObjectInput): Unit = {
numFields = in.readInt()
+ resultRow = new UnsafeRow(numFields)
start = in.readLong()
val length = in.readInt()
// read sizes of rows
@@ -523,44 +503,32 @@ private[joins] object LongHashedRelation {
}
}
- if (keyIsUnique) {
- if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
- // The keys are dense enough, so use LongArrayRelation
- val length = (maxKey - minKey).toInt + 1
- val sizes = new Array[Int](length)
- val offsets = new Array[Int](length)
- var offset = 0
- var i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- offsets(i) = offset
- sizes(i) = rows(0).getSizeInBytes
- offset += sizes(i)
- }
- i += 1
- }
- val bytes = new Array[Byte](offset)
- i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
- }
- i += 1
+ if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+ // The keys are dense enough, so use LongArrayRelation
+ val length = (maxKey - minKey).toInt + 1
+ val sizes = new Array[Int](length)
+ val offsets = new Array[Int](length)
+ var offset = 0
+ var i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ offsets(i) = offset
+ sizes(i) = rows(0).getSizeInBytes
+ offset += sizes(i)
}
- new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
-
- } else {
- // all the keys are unique, one row per key.
- val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- uniqHashTable.put(entry.getKey, entry.getValue()(0))
+ i += 1
+ }
+ val bytes = new Array[Byte](offset)
+ i = 0
+ while (i < length) {
+ val rows = hashTable.get(i + minKey)
+ if (rows != null) {
+ rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
}
- new UniqueLongHashedRelation(uniqHashTable)
+ i += 1
}
+ new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
} else {
new GeneralLongHashedRelation(hashTable)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index e3a2eaea5d..c63faacf33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -102,39 +102,9 @@ case class ShuffledHashJoin(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
-
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
- val joinedRow = new JoinedRow
- joinType match {
- case Inner =>
- hashJoin(streamIter, hashed, numOutputRows)
-
- case LeftSemi =>
- hashSemiJoin(streamIter, hashed, numOutputRows)
-
- case LeftOuter =>
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
- streamIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
- })
-
- case RightOuter =>
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
- streamIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
- })
-
- case x =>
- throw new IllegalArgumentException(
- s"ShuffledHashJoin should not take $x as the JoinType")
- }
+ join(streamIter, hashed, numOutputRows)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 289e1b6db9..3566ef3043 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -166,20 +166,35 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("broadcast hash join") {
- val N = 100 << 20
+ val N = 20 << 20
val M = 1 << 16
val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("Join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
}
/*
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X
- Join w long codegen=true 735 / 853 142.7 7.0 7.8X
+ Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X
+ Join w long codegen=true 275 / 352 76.2 13.1 19.4X
+ */
+
+ runBenchmark("Join w long duplicated", N) {
+ val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k"))
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
+ }
+
+ /**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X
+ Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X
*/
val dim2 = broadcast(sqlContext.range(M)
@@ -187,16 +202,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
runBenchmark("Join w 2 ints", N) {
sqlContext.range(N).join(dim2,
- (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
- && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
+ (col("id") % M).cast(IntegerType) === col("k1")
+ && (col("id") % M).cast(IntegerType) === col("k2")).count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X
- Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X
+ Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X
+ Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X
*/
val dim3 = broadcast(sqlContext.range(M)
@@ -204,16 +220,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
runBenchmark("Join w 2 longs", N) {
sqlContext.range(N).join(dim3,
- (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+ (col("id") % M) === col("k1") && (col("id") % M) === col("k2"))
.count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 longs codegen=false 12725 / 13158 8.2 121.4 1.0X
- Join w 2 longs codegen=true 6044 / 6771 17.3 57.6 2.1X
+ Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X
+ Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X
*/
val dim4 = broadcast(sqlContext.range(M)
@@ -227,34 +244,36 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 longs duplicated codegen=false 13066 / 13710 8.0 124.6 1.0X
- Join w 2 longs duplicated codegen=true 7122 / 7277 14.7 67.9 1.8X
+ Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X
+ Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X
*/
runBenchmark("outer join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X
- outer join w long codegen=true 769 / 796 136.3 7.3 19.9X
+ outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X
+ outer join w long codegen=true 216 / 226 97.2 10.3 26.3X
*/
runBenchmark("semi join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- semi join w long codegen=false 5804 / 5969 18.1 55.3 1.0X
- semi join w long codegen=true 814 / 934 128.8 7.8 7.1X
+ semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X
+ semi join w long codegen=true 211 / 229 99.2 10.1 22.2X
*/
}
@@ -303,11 +322,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X
- shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X
+ shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X
+ shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X
*/
}
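
The benchmark tables above report a best/avg wall time plus derived columns; a quick arithmetic check against the "Join w long codegen=true" row (N = 20 << 20 rows, best time 275 ms) shows how Rate(M/s), Per Row(ns), and the Relative factor are obtained, with small differences due to rounding in the harness:

    val rows = 20L << 20                  // N = 20971520
    val bestMs = 275.0                    // best time of "Join w long codegen=true"
    val rateMps = rows / bestMs / 1000.0  // ~= 76.3 M rows/s   (table: 76.2)
    val perRowNs = bestMs * 1e6 / rows    // ~= 13.1 ns per row (table: 13.1)
    val relative = 5351.0 / bestMs        // ~= 19.5X vs codegen=false (table: 19.4X)
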
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index ed4cc1c4c4..ed87a99439 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -34,20 +34,20 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val toUnsafe = UnsafeProjection.create(schema)
- val unsafeData = data.map(toUnsafe(_).copy()).toArray
+ val unsafeData = data.map(toUnsafe(_).copy())
val buildKey = Seq(BoundReference(0, IntegerType, false))
val keyGenerator = UnsafeProjection.create(buildKey)
val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
- assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
- assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+ assert(hashed.get(unsafeData(1)).toArray === Array(unsafeData(1)))
assert(hashed.get(toUnsafe(InternalRow(10))) === null)
val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
data2 += unsafeData(2).copy()
- assert(hashed.get(unsafeData(2)) === data2)
+ assert(hashed.get(unsafeData(2)).toArray === data2.toArray)
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
@@ -56,10 +56,10 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val hashed2 = new UnsafeHashedRelation()
hashed2.readExternal(in)
- assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
- assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed2.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+ assert(hashed2.get(unsafeData(1)).toArray === Array(unsafeData(1)))
assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
- assert(hashed2.get(unsafeData(2)) === data2)
+ assert(hashed2.get(unsafeData(2)).toArray === data2)
val os2 = new ByteArrayOutputStream()
val out2 = new ObjectOutputStream(os2)