Diffstat (limited to 'sql/core'): 6 files changed, 268 insertions, 401 deletions
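
The heart of this patch is an API change: the UniqueHashedRelation trait and its subclasses are folded into a single HashedRelation interface whose get() methods return an Iterator[InternalRow] (or null on a miss) instead of a materialized Seq, with keyIsUnique replacing the isInstanceOf[UniqueHashedRelation] checks and asReadOnlyCopy() giving each task its own probe-safe handle. A minimal sketch of how a caller consumes the new contract follows; probe and consume are illustrative names, not part of the patch, while get/getValue/keyIsUnique/asReadOnlyCopy are the methods declared in the diff below:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.joins.HashedRelation

    // Sketch only: consuming the post-patch HashedRelation contract.
    def probe(relation: HashedRelation, key: InternalRow)
        (consume: InternalRow => Unit): Unit = {
      // Probe a read-only copy: get()/getValue() re-use a single mutable
      // UnsafeRow buffer internally, so sharing one instance across
      // threads is unsafe.
      val local = relation.asReadOnlyCopy()
      if (local.keyIsUnique) {
        // One row per key: getValue() returns the match, or null on a miss.
        val row = local.getValue(key)
        if (row != null) consume(row)
      } else {
        // Possibly many rows per key: get() returns an iterator, or null.
        val matches = local.get(key)
        if (matches != null) matches.foreach(consume)
      }
    }

Returning an iterator lets UnsafeHashedRelation point one re-used UnsafeRow at successive values in the BytesToBytesMap instead of copying every match into a CompactBuffer, which is what the generated-code changes below (for-loop over a buffer becoming a while-loop over an iterator) exploit.
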
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 41e566c27b..67ac9e94ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -68,37 +68,9 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow() - val hashTable = broadcastRelation.value - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) - val keyGenerator = streamSideKeyGenerator - val resultProj = createResultProjection - - joinType match { - case Inner => - hashJoin(streamedIter, hashTable, numOutputRows) - - case LeftOuter => - streamedIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) - } - - case RightOuter => - streamedIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) - } - - case LeftSemi => - hashSemiJoin(streamedIter, hashTable, numOutputRows) - - case x => - throw new IllegalArgumentException( - s"BroadcastHashJoin should not take $x as the JoinType") - } + val hashed = broadcastRelation.value.asReadOnlyCopy() + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) + join(streamedIter, hashed, numOutputRows) } } @@ -132,7 +104,7 @@ case class BroadcastHashJoin( val clsName = broadcastRelation.value.getClass.getName ctx.addMutableState(clsName, relationTerm, s""" - | $relationTerm = ($clsName) $broadcast.value(); + | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) (broadcastRelation, relationTerm) @@ -217,7 +189,7 @@ case class BroadcastHashJoin( case BuildLeft => buildVars ++ input case BuildRight => input ++ buildVars } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side |${keyEv.code} @@ -232,18 +204,15 @@ case class BroadcastHashJoin( } else { ctx.copyResult = true val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName s""" |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |$iteratorCls $matches = $anyNull ? 
null : ($iteratorCls)$relationTerm.get(${keyEv.value}); |if ($matches == null) continue; - |int $size = $matches.size(); - |for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + |while ($matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); | $checkCondition | $numOutput.add(1); | ${consume(ctx, resultVars)} @@ -287,7 +256,7 @@ case class BroadcastHashJoin( case BuildLeft => buildVars ++ input case BuildRight => input ++ buildVars } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side |${keyEv.code} @@ -306,22 +275,21 @@ case class BroadcastHashJoin( } else { ctx.copyResult = true val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val iteratorCls = classOf[Iterator[UnsafeRow]].getName val i = ctx.freshName("i") - val size = ctx.freshName("size") val found = ctx.freshName("found") s""" |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); - |int $size = $matches != null ? $matches.size() : 0; + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); |boolean $found = false; |// the last iteration of this loop is to emit an empty row if there is no matched rows. - |for (int $i = 0; $i <= $size; $i++) { - | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null; + |while ($matches != null && $matches.hasNext() || !$found) { + | UnsafeRow $matched = $matches != null && $matches.hasNext() ? + | (UnsafeRow) $matches.next() : null; | ${checkCondition.trim} - | if (!$conditionPassed || ($i == $size && $found)) continue; + | if (!$conditionPassed) continue; | $found = true; | $numOutput.add(1); | ${consume(ctx, resultVars)} @@ -356,7 +324,7 @@ case class BroadcastHashJoin( "" } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side |${keyEv.code} @@ -369,23 +337,19 @@ case class BroadcastHashJoin( """.stripMargin } else { val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName val found = ctx.freshName("found") s""" |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |$iteratorCls $matches = $anyNull ? 
null : ($iteratorCls)$relationTerm.get(${keyEv.value}); |if ($matches == null) continue; - |int $size = $matches.size(); |boolean $found = false; - |for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + |while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); | $checkCondition | $found = true; - | break; |} |if (!$found) continue; |$numOutput.add(1); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index c298b7dee0..b7c0f3e7d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.joins import java.util.NoSuchElementException +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} import org.apache.spark.util.collection.CompactBuffer @@ -110,169 +111,113 @@ trait HashJoin { sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] } - protected def buildSideKeyGenerator: Projection = + protected def buildSideKeyGenerator(): Projection = UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) - protected def streamSideKeyGenerator: Projection = + protected def streamSideKeyGenerator(): UnsafeProjection = UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) @transient private[this] lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) } else { (r: InternalRow) => true } - protected def createResultProjection: (InternalRow) => InternalRow = - UnsafeProjection.create(self.schema) - - protected def hashJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - new Iterator[InternalRow] { - private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: Seq[InternalRow] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. 
- private[this] val joinRow = new JoinedRow - private[this] val resultProjection = createResultProjection - - private[this] val joinKeys = streamSideKeyGenerator - - override final def hasNext: Boolean = { - while (true) { - // check if it's end of current matches - if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) { - currentHashMatches = null - currentMatchPosition = -1 - } - - // find the next match - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - val key = joinKeys(currentStreamedRow) - if (!key.anyNull) { - currentHashMatches = hashedRelation.get(key) - if (currentHashMatches != null) { - currentMatchPosition = 0 - } - } - } - if (currentHashMatches == null) { - return false - } - - // found some matches - buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - if (boundCondition(joinRow)) { - return true - } else { - currentMatchPosition += 1 - } - } - false // unreachable - } - - override final def next(): InternalRow = { - // next() could be called without calling hasNext() - if (hasNext) { - currentMatchPosition += 1 - numOutputRows += 1 - resultProjection(joinRow) - } else { - throw new NoSuchElementException - } - } + protected def createResultProjection(): (InternalRow) => InternalRow = { + if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) + } else { + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) } } - @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - - @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - - protected[this] def leftOuterIterator( - key: InternalRow, - joinedRow: JoinedRow, - rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } else { - temp - } + private def innerJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinRow = new JoinedRow + val joinKeys = streamSideKeyGenerator() + streamIter.flatMap { srow => + joinRow.withLeft(srow) + val matches = hashedRelation.get(joinKeys(srow)) + if (matches != null) { + matches.map(joinRow.withRight(_)).filter(boundCondition) } else { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + Seq.empty } } - ret.iterator } - protected[this] def rightOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (leftIter != null) { - leftIter.collect { - case l if 
boundCondition(joinedRow.withLeft(l)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() + private def outerJoin( + streamedIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinedRow = new JoinedRow() + val keyGenerator = streamSideKeyGenerator() + val nullRow = new GenericInternalRow(buildPlan.output.length) + + streamedIter.flatMap { currentRow => + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + val buildIter = hashedRelation.get(rowKey) + new RowIterator { + private var found = false + override def advanceNext(): Boolean = { + while (buildIter != null && buildIter.hasNext) { + val nextBuildRow = buildIter.next() + if (boundCondition(joinedRow.withRight(nextBuildRow))) { + found = true + return true } } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } else { - temp + if (!found) { + joinedRow.withRight(nullRow) + found = true + return true + } + false } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } + override def getRow: InternalRow = joinedRow + }.toScala } - ret.iterator } - protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val joinKeys = streamSideKeyGenerator + private def semiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() val joinedRow = new JoinedRow streamIter.filter { current => val key = joinKeys(current) - lazy val rowBuffer = hashedRelation.get(key) - val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists { + lazy val buildIter = hashedRelation.get(key) + !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) }) - if (r) numOutputRows += 1 - r + } + } + + protected def join( + streamedIter: Iterator[InternalRow], + hashed: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + + val joinedIter = joinType match { + case Inner => + innerJoin(streamedIter, hashed) + case LeftOuter | RightOuter => + outerJoin(streamedIter, hashed) + case LeftSemi => + semiJoin(streamedIter, hashed) + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } + + val resultProj = createResultProjection + joinedIter.map { r => + numOutputRows += 1 + resultProj(r) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 91c470d187..5ccb435686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.unsafe.Platform -import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer @@ -39,17 +38,43 @@ import org.apache.spark.util.collection.CompactBuffer private[execution] sealed trait 
HashedRelation { /** * Returns matched rows. + * + * Returns null if there is no matched rows. */ - def get(key: InternalRow): Seq[InternalRow] + def get(key: InternalRow): Iterator[InternalRow] /** * Returns matched rows for a key that has only one column with LongType. + * + * Returns null if there is no matched rows. + */ + def get(key: Long): Iterator[InternalRow] = { + throw new UnsupportedOperationException + } + + /** + * Returns the matched single row. */ - def get(key: Long): Seq[InternalRow] = { + def getValue(key: InternalRow): InternalRow + + /** + * Returns the matched single row with key that have only one column of LongType. + */ + def getValue(key: Long): InternalRow = { throw new UnsupportedOperationException } /** + * Returns true iff all the keys are unique. + */ + def keyIsUnique: Boolean + + /** + * Returns a read-only copy of this, to be safely used in current thread. + */ + def asReadOnlyCopy(): HashedRelation + + /** * Returns the size of used memory. */ def getMemorySize: Long = 1L // to make the test happy @@ -76,44 +101,6 @@ private[execution] sealed trait HashedRelation { } } -/** - * Interface for a hashed relation that have only one row per key. - * - * We should call getValue() for better performance. - */ -private[execution] trait UniqueHashedRelation extends HashedRelation { - - /** - * Returns the matched single row. - */ - def getValue(key: InternalRow): InternalRow - - /** - * Returns the matched single row with key that have only one column of LongType. - */ - def getValue(key: Long): InternalRow = { - throw new UnsupportedOperationException - } - - override def get(key: InternalRow): Seq[InternalRow] = { - val row = getValue(key) - if (row != null) { - CompactBuffer[InternalRow](row) - } else { - null - } - } - - override def get(key: Long): Seq[InternalRow] = { - val row = getValue(key) - if (row != null) { - CompactBuffer[InternalRow](row) - } else { - null - } - } -} - private[execution] object HashedRelation { /** @@ -150,6 +137,11 @@ private[joins] class UnsafeHashedRelation( private[joins] def this() = this(0, null) // Needed for serialization + override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues() + + override def asReadOnlyCopy(): UnsafeHashedRelation = + new UnsafeHashedRelation(numFields, binaryMap) + override def getMemorySize: Long = { binaryMap.getTotalMemoryConsumption } @@ -158,23 +150,39 @@ private[joins] class UnsafeHashedRelation( binaryMap.getTotalMemoryConsumption } - override def get(key: InternalRow): Seq[InternalRow] = { + // re-used in get()/getValue() + var resultRow = new UnsafeRow(numFields) + + override def get(key: InternalRow): Iterator[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] val map = binaryMap // avoid the compiler error val loc = new map.Location // this could be allocated in stack binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) if (loc.isDefined) { - val buffer = CompactBuffer[UnsafeRow]() - val row = new UnsafeRow(numFields) - row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - buffer += row - while (loc.nextValue()) { - val row = new UnsafeRow(numFields) - row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - buffer += row + new Iterator[UnsafeRow] { + private var _hasNext = true + override def hasNext: Boolean = _hasNext + override def next(): UnsafeRow = { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + _hasNext = 
loc.nextValue() + resultRow + } } - buffer + } else { + null + } + } + + def getValue(key: InternalRow): InternalRow = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + resultRow } else { null } @@ -212,6 +220,7 @@ private[joins] class UnsafeHashedRelation( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { numFields = in.readInt() + resultRow = new UnsafeRow(numFields) val nKeys = in.readInt() val nValues = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory @@ -263,29 +272,6 @@ private[joins] class UnsafeHashedRelation( } } -/** - * A HashedRelation for UnsafeRow with unique keys. - */ -private[joins] final class UniqueUnsafeHashedRelation( - private var numFields: Int, - private var binaryMap: BytesToBytesMap) - extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation { - def getValue(key: InternalRow): InternalRow = { - val unsafeKey = key.asInstanceOf[UnsafeRow] - val map = binaryMap // avoid the compiler error - val loc = new map.Location // this could be allocated in stack - binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) - if (loc.isDefined) { - val row = new UnsafeRow(numFields) - row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - row - } else { - null - } - } -} - private[joins] object UnsafeHashedRelation { def apply( @@ -315,17 +301,12 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows var numFields = 0 - // Whether all the keys are unique or not - var allUnique: Boolean = true while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] numFields = row.numFields() val key = keyGenerator(row) if (!key.anyNull) { val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) - if (loc.isDefined) { - allUnique = false - } val success = loc.append( key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) @@ -336,11 +317,7 @@ private[joins] object UnsafeHashedRelation { } } - if (allUnique) { - new UniqueUnsafeHashedRelation(numFields, binaryMap) - } else { - new UnsafeHashedRelation(numFields, binaryMap) - } + new UnsafeHashedRelation(numFields, binaryMap) } } @@ -348,9 +325,12 @@ private[joins] object UnsafeHashedRelation { * An interface for a hashed relation that the key is a Long. 
*/ private[joins] trait LongHashedRelation extends HashedRelation { - override def get(key: InternalRow): Seq[InternalRow] = { + override def get(key: InternalRow): Iterator[InternalRow] = { get(key.getLong(0)) } + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) + } } private[joins] final class GeneralLongHashedRelation( @@ -360,30 +340,18 @@ private[joins] final class GeneralLongHashedRelation( // Needed for serialization (it is public to make Java serialization work) def this() = this(null) - override def get(key: Long): Seq[InternalRow] = hashTable.get(key) + override def keyIsUnique: Boolean = false - override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) - } - - override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) - } -} + override def asReadOnlyCopy(): GeneralLongHashedRelation = + new GeneralLongHashedRelation(hashTable) -private[joins] final class UniqueLongHashedRelation( - private var hashTable: JavaHashMap[Long, UnsafeRow]) - extends UniqueHashedRelation with LongHashedRelation with Externalizable { - - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) - - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) - } - - override def getValue(key: Long): InternalRow = { - hashTable.get(key) + override def get(key: Long): Iterator[InternalRow] = { + val rows = hashTable.get(key) + if (rows != null) { + rows.toIterator + } else { + null + } } override def writeExternal(out: ObjectOutput): Unit = { @@ -422,25 +390,36 @@ private[joins] final class LongArrayRelation( private var offsets: Array[Int], private var sizes: Array[Int], private var bytes: Array[Byte] - ) extends UniqueHashedRelation with LongHashedRelation with Externalizable { + ) extends LongHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) def this() = this(0, 0L, null, null, null) - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) + override def keyIsUnique: Boolean = true + + override def asReadOnlyCopy(): LongArrayRelation = { + new LongArrayRelation(numFields, start, offsets, sizes, bytes) } override def getMemorySize: Long = { offsets.length * 4 + sizes.length * 4 + bytes.length } + override def get(key: Long): Iterator[InternalRow] = { + val row = getValue(key) + if (row != null) { + Seq(row).toIterator + } else { + null + } + } + + var resultRow = new UnsafeRow(numFields) override def getValue(key: Long): InternalRow = { val idx = (key - start).toInt if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { - val result = new UnsafeRow(numFields) - result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) - result + resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) + resultRow } else { null } @@ -461,6 +440,7 @@ private[joins] final class LongArrayRelation( override def readExternal(in: ObjectInput): Unit = { numFields = in.readInt() + resultRow = new UnsafeRow(numFields) start = in.readLong() val length = in.readInt() // read sizes of rows @@ -523,44 +503,32 @@ private[joins] object LongHashedRelation { } } - if (keyIsUnique) { - if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { - // The keys are dense enough, so use LongArrayRelation - val length = (maxKey - minKey).toInt + 1 - val sizes = new Array[Int](length) - val offsets = new 
Array[Int](length) - var offset = 0 - var i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - offsets(i) = offset - sizes(i) = rows(0).getSizeInBytes - offset += sizes(i) - } - i += 1 - } - val bytes = new Array[Byte](offset) - i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) - } - i += 1 + if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { + // The keys are dense enough, so use LongArrayRelation + val length = (maxKey - minKey).toInt + 1 + val sizes = new Array[Int](length) + val offsets = new Array[Int](length) + var offset = 0 + var i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + offsets(i) = offset + sizes(i) = rows(0).getSizeInBytes + offset += sizes(i) } - new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) - - } else { - // all the keys are unique, one row per key. - val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size) - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - uniqHashTable.put(entry.getKey, entry.getValue()(0)) + i += 1 + } + val bytes = new Array[Byte](offset) + i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) } - new UniqueLongHashedRelation(uniqHashTable) + i += 1 } + new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) } else { new GeneralLongHashedRelation(hashTable) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index e3a2eaea5d..c63faacf33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -102,39 +102,9 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) - val joinedRow = new JoinedRow - joinType match { - case Inner => - hashJoin(streamIter, hashed, numOutputRows) - - case LeftSemi => - hashSemiJoin(streamIter, hashed, numOutputRows) - - case LeftOuter => - val keyGenerator = streamSideKeyGenerator - val resultProj = createResultProjection - streamIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - val keyGenerator = streamSideKeyGenerator - val resultProj = createResultProjection - streamIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case x => - throw new IllegalArgumentException( - s"ShuffledHashJoin should not take $x as the JoinType") - } + join(streamIter, hashed, numOutputRows) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 289e1b6db9..3566ef3043 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -166,20 +166,35 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("broadcast hash join") { - val N = 100 << 20 + val N = 20 << 20 val M = 1 << 16 val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("Join w long", N) { - sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count() + sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() } /* + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X - Join w long codegen=true 735 / 853 142.7 7.0 7.8X + Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X + Join w long codegen=true 275 / 352 76.2 13.1 19.4X + */ + + runBenchmark("Join w long duplicated", N) { + val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k")) + sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X + Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X */ val dim2 = broadcast(sqlContext.range(M) @@ -187,16 +202,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { runBenchmark("Join w 2 ints", N) { sqlContext.range(N).join(dim2, - (col("id") bitwiseAND M).cast(IntegerType) === col("k1") - && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count() + (col("id") % M).cast(IntegerType) === col("k1") + && (col("id") % M).cast(IntegerType) === col("k2")).count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X - Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X + Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X + Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X */ val dim3 = broadcast(sqlContext.range(M) @@ -204,16 +220,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { runBenchmark("Join w 2 longs", N) { sqlContext.range(N).join(dim3, - (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) + (col("id") % M) === col("k1") && (col("id") % M) === col("k2")) .count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 longs codegen=false 12725 / 13158 8.2 121.4 1.0X - Join w 2 longs codegen=true 6044 / 6771 17.3 57.6 2.1X + Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X + Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X */ val dim4 = 
broadcast(sqlContext.range(M) @@ -227,34 +244,36 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 longs duplicated codegen=false 13066 / 13710 8.0 124.6 1.0X - Join w 2 longs duplicated codegen=true 7122 / 7277 14.7 67.9 1.8X + Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X + Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X */ runBenchmark("outer join w long", N) { - sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count() + sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X - outer join w long codegen=true 769 / 796 136.3 7.3 19.9X + outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X + outer join w long codegen=true 216 / 226 97.2 10.3 26.3X */ runBenchmark("semi join w long", N) { - sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count() + sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long codegen=false 5804 / 5969 18.1 55.3 1.0X - semi join w long codegen=true 814 / 934 128.8 7.8 7.1X + semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X + semi join w long codegen=true 211 / 229 99.2 10.1 22.2X */ } @@ -303,11 +322,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X - shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X + shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X + shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ed4cc1c4c4..ed87a99439 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -34,20 +34,20 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) - val unsafeData = data.map(toUnsafe(_).copy()).toArray + val unsafeData = data.map(toUnsafe(_).copy()) val buildKey = 
Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) - assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) + assert(hashed.get(unsafeData(1)).toArray === Array(unsafeData(1))) assert(hashed.get(toUnsafe(InternalRow(10))) === null) val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) data2 += unsafeData(2).copy() - assert(hashed.get(unsafeData(2)) === data2) + assert(hashed.get(unsafeData(2)).toArray === data2.toArray) val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) @@ -56,10 +56,10 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) val hashed2 = new UnsafeHashedRelation() hashed2.readExternal(in) - assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) - assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(unsafeData(0)).toArray === Array(unsafeData(0))) + assert(hashed2.get(unsafeData(1)).toArray === Array(unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) - assert(hashed2.get(unsafeData(2)) === data2) + assert(hashed2.get(unsafeData(2)).toArray === data2) val os2 = new ByteArrayOutputStream() val out2 = new ObjectOutputStream(os2) |
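
One detail from the HashedRelation.scala hunks above deserves a closing note: when all long keys are unique and dense enough (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR), the builder still emits a LongArrayRelation, which packs every row into one byte array and resolves a key with plain array arithmetic rather than hashing. A simplified model of that lookup, with illustrative names; the real relation points a re-used UnsafeRow at the slice instead of copying it:

    // Simplified model of LongArrayRelation's dense layout.
    class DenseLongLookup(
        start: Long,           // the minimum key; index 0 of the arrays
        offsets: Array[Int],   // byte offset of each row inside `bytes`
        sizes: Array[Int],     // row length in bytes; 0 marks an absent key
        bytes: Array[Byte]) {  // all rows packed back to back

      /** Returns a copy of the row bytes for `key`, or null if absent. */
      def lookup(key: Long): Array[Byte] = {
        val idx = (key - start).toInt
        if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
          java.util.Arrays.copyOfRange(
            bytes, offsets(idx), offsets(idx) + sizes(idx))
        } else {
          null
        }
      }
    }

Because keyIsUnique is true by construction for this layout, probes go through getValue() and skip iterator allocation entirely, which is consistent with the single-key broadcast-join speedups reported in the benchmark hunks above.
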