diff options
20 files changed, 444 insertions, 135 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 6ce03a48e9..7f08bf7b74 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions; import java.io.IOException; import java.io.OutputStream; -import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.UTF8String; @@ -354,7 +355,7 @@ public final class UnsafeRow extends MutableRow { * This method is only supported on UnsafeRows that do not use ObjectPools. */ @Override - public InternalRow copy() { + public UnsafeRow copy() { if (pool != null) { throw new UnsupportedOperationException( "Copy is not supported for UnsafeRows that use object pools"); @@ -405,7 +406,50 @@ public final class UnsafeRow extends MutableRow { } @Override + public int hashCode() { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UnsafeRow) { + UnsafeRow o = (UnsafeRow) other; + return (sizeInBytes == o.sizeInBytes) && + ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, + sizeInBytes); + } + return false; + } + + /** + * Returns the underlying bytes for this UnsafeRow. 
+ */ + public byte[] getBytes() { + if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + && (((byte[]) baseObject).length == sizeInBytes)) { + return (byte[]) baseObject; + } else { + byte[] bytes = new byte[sizeInBytes]; + PlatformDependent.copyMemory(baseObject, baseOffset, bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + return bytes; + } + } + + // This is for debugging + @Override + public String toString() { + StringBuilder build = new StringBuilder("["); + for (int i = 0; i < sizeInBytes; i += 8) { + build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); + build.append(','); + } + build.append(']'); + return build.toString(); + } + + @Override public boolean anyNull() { - return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index d1d81c87bb..39fd6e1bc6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -28,11 +28,10 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import 
org.apache.spark.util.collection.unsafe.sort.RecordComparator; @@ -176,12 +175,7 @@ final class UnsafeExternalRowSorter { */ public static boolean supportsSchema(StructType schema) { // TODO: add spilling note to explain why we do this for now: - for (StructField field : schema.fields()) { - if (!UnsafeColumnWriter.canEmbed(field.dataType())) { - return false; - } - } - return true; + return UnsafeProjection.canSupport(schema); } private static final class RowComparator extends RecordComparator { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b10a3c8774..4a13b687bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ /** @@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, $dataType]" - override def eval(input: InternalRow): Any = input(ordinal) + // Use special getter for primitive types (for UnsafeRow) + override def eval(input: InternalRow): Any = { + if (input.isNullAt(ordinal)) { + null + } else { + dataType match { + case BooleanType => input.getBoolean(ordinal) + case ByteType => input.getByte(ordinal) + case ShortType => input.getShort(ordinal) + case IntegerType | DateType => input.getInt(ordinal) + case LongType | TimestampType => input.getLong(ordinal) + case FloatType => input.getFloat(ordinal) + case DoubleType => 
input.getDouble(ordinal) + case _ => input.get(ordinal) + } + } + } override def name: String = s"i[$ordinal]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 24b01ea551..69758e653e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -83,12 +83,42 @@ abstract class UnsafeProjection extends Projection { } object UnsafeProjection { + + /* + * Returns whether UnsafeProjection can support given StructType, Array[DataType] or + * Seq[Expression]. + */ + def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) + def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) + def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) + + /** + * Returns an UnsafeProjection for given StructType. + */ def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) - def create(fields: Seq[DataType]): UnsafeProjection = { + /** + * Returns an UnsafeProjection for given Array of DataTypes. + */ + def create(fields: Array[DataType]): UnsafeProjection = { val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + create(exprs) + } + + /** + * Returns an UnsafeProjection for given sequence of Expressions (bounded). + */ + def create(exprs: Seq[Expression]): UnsafeProjection = { GenerateUnsafeProjection.generate(exprs) } + + /** + * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to + * `inputSchema`. 
+ */ + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { + create(exprs.map(BindReferences.bindReference(_, inputSchema))) + } } /** @@ -96,6 +126,8 @@ object UnsafeProjection { */ case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => new BoundReference(idx, dt, true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 7ffdce60d2..abaa4a6ce8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ab757fc7de..c9d1a880f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.joins
+import scala.concurrent._
+import scala.concurrent.duration._
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
-import scala.collection.JavaConversions._
-import scala.concurrent._
-import scala.concurrent.duration._
-
/**
* :: DeveloperApi ::
* Performs an outer hash join for two child relations. When the output RDD of this operator is
@@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
- private[this] lazy val (buildPlan, streamedPlan) = joinType match {
- case RightOuter => (left, right)
- case LeftOuter => (right, left)
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
-
- private[this] lazy val (buildKeys, streamedKeys) = joinType match {
- case RightOuter => (leftKeys, rightKeys)
- case LeftOuter => (rightKeys, leftKeys)
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
-
@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- // buildHashTable uses code-generated rows as keys, which are not serializable
- val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
+ val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
@@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin( streamedPlan.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
- val keyGenerator = newProjection(streamedKeys, streamedPlan.output)
+ val keyGenerator = streamedKeyGenerator
joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
})
case RightOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
})
case x =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 2750f58b00..f71c0ce352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -40,15 +40,14 @@ case class BroadcastLeftSemiJoinHash( val buildIter = right.execute().map(_.copy()).collect().toIterator if (condition.isEmpty) { - // rowKey may be not serializable (from codegen) - val hashSet = buildKeyHashSet(buildIter, copy = true) + val hashSet = buildKeyHashSet(buildIter) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = buildHashRelation(buildIter) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 60b4266fad..700636966f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + + @transient private[this] lazy val resultProjection: Projection = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow 
= r + } + } + } + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = { @@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -86,11 +100,11 @@ case class BroadcastNestedLoopJoin( val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case _ => @@ -100,9 +114,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() + matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() + matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -110,12 +124,9 @@ case class BroadcastNestedLoopJoin( } val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } + val allIncludedBroadcastTuples = 
includedBroadcastTuples.fold( + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + )(_ ++ _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -127,8 +138,10 @@ case class BroadcastNestedLoopJoin( while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => + buf += resultProjection(new JoinedRow(leftNulls, rel(i))) + case (LeftOuter | FullOuter, BuildLeft) => + buf += resultProjection(new JoinedRow(rel(i), rightNulls)) case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ff85ea3f6a..ae34409bcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,11 +44,20 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator: Projection = - newProjection(buildKeys, buildPlan.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe - @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = - newMutableProjection(streamedKeys, streamedPlan.output) + @transient protected lazy val streamSideKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newMutableProjection(streamedKeys, 
streamedPlan.output)() + } protected def hashJoin( streamIter: Iterator[InternalRow], @@ -61,8 +70,17 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow2 + private[this] val resultProjection: Projection = { + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } - private[this] val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator override final def hasNext: Boolean = (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || @@ -74,7 +92,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 - ret + resultProjection(ret) } /** @@ -89,8 +107,9 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashedRelation.get(joinKeys.currentValue) + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashedRelation.get(key) } } @@ -103,4 +122,12 @@ trait HashJoin { } } } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 74a7db7761..6bf2f82954 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow 
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -38,7 +38,7 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan -override def outputPartitioning: Partitioning = joinType match { + override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -59,6 +59,49 @@ override def outputPartitioning: Partitioning = joinType match { } } + protected[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && joinType != FullOuter + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe + + protected[this] def streamedKeyGenerator(): Projection = { + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newProjection(streamedKeys, streamedPlan.output) + } + } + + @transient private[this] lazy val resultProjection: Projection = { + if (supportUnsafe) { + 
UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -76,16 +119,20 @@ override def outputPartitioning: Partitioning = joinType match { rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } } ret.iterator @@ -97,17 +144,21 @@ override def outputPartitioning: Partitioning = joinType match { joinedRow: JoinedRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy() + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => + resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } } ret.iterator @@ -159,6 +210,7 @@ override def outputPartitioning: Partitioning = joinType match { } } + // This is only used by FullOuter protected[this] def buildHashTable( iter: 
Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { @@ -178,4 +230,12 @@ override def outputPartitioning: Partitioning = joinType match { hashTable } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 1b983bc3a9..7f49264d40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -32,34 +32,45 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - @transient protected lazy val rightKeyGenerator: Projection = - newProjection(rightKeys, right.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(left.schema)) + } + + override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = supportUnsafe + + @transient protected lazy val leftKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newMutableProjection(leftKeys, left.output)() + } - @transient protected lazy val leftKeyGenerator: () => MutableProjection = - newMutableProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newMutableProjection(rightKeys, right.output)() + } @transient private lazy val boundCondition = 
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet( - buildIter: Iterator[InternalRow], - copy: Boolean): java.util.Set[InternalRow] = { + protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() var currentRow: InternalRow = null // Create a Hash set of buildKeys + val rightKey = rightKeyGenerator while (buildIter.hasNext) { currentRow = buildIter.next() - val rowKey = rightKeyGenerator(currentRow) + val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { - if (copy) { - hashSet.add(rowKey.copy()) - } else { - // rowKey may be not serializable (from codegen) - hashSet.add(rowKey) - } + hashSet.add(rowKey.copy()) } } } @@ -67,25 +78,34 @@ trait HashSemiJoin { } protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() - val joinedRow = new JoinedRow + streamIter: Iterator[InternalRow], + hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator streamIter.filter(current => { - lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue) - !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists { - (build: InternalRow) => boundCondition(joinedRow(current, build)) - } + val key = joinKeys(current) + !key.anyNull && hashSet.contains(key) }) } + protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, rightKeys, right) + } else { + HashedRelation(buildIter, newProjection(rightKeys, right.output)) + } + } + protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() + hashedRelation: HashedRelation): Iterator[InternalRow] 
= { + val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow - streamIter.filter(current => { - !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue) - }) + streamIter.filter { current => + val key = joinKeys(current) + lazy val rowBuffer = hashedRelation.get(key) + !key.anyNull && rowBuffer != null && rowBuffer.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6b51f5d415..8d5731afd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.joins -import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.CompactBuffer @@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR } } - // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. @@ -148,3 +148,80 @@ private[joins] object HashedRelation { } } } + + +/** + * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a + * sequence of values. 
+ * + * TODO(davies): use BytesToBytesMap + */ +private[joins] final class UnsafeHashedRelation( + private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization + + override def get(key: InternalRow): CompactBuffer[InternalRow] = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + // Thanks to type eraser + hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + } + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] object UnsafeHashedRelation { + + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + buildPlan: SparkPlan, + sizeEstimate: Int = 64): HashedRelation = { + val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) + apply(input, boundedKeys, buildPlan.schema, sizeEstimate) + } + + // Used for tests + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + rowSchema: StructType, + sizeEstimate: Int): HashedRelation = { + + // TODO: Use BytesToBytesMap. 
+ val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) + val toUnsafe = UnsafeProjection.create(rowSchema) + val keyGenerator = UnsafeProjection.create(buildKeys) + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + val currentRow = input.next() + val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { + currentRow.asInstanceOf[UnsafeRow] + } else { + toUnsafe(currentRow) + } + val rowKey = keyGenerator(unsafeRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(rowKey.copy(), newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += unsafeRow.copy() + } + } + + new UnsafeHashedRelation(hashTable) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index db5be9f453..4443455ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -39,6 +39,9 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output + override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 9eaac817d9..874712a4e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -43,10 +43,10 @@ case class LeftSemiJoinHash( protected override def doExecute(): 
RDD[InternalRow] = { right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter, copy = false) + val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = buildHashRelation(buildIter) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60..948d0ccebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + val hashed = buildHashRelation(buildIter) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ab0a6ad56a..f54f1edd38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?)
joinType match {
case LeftOuter =>
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
- val keyGenerator = newProjection(leftKeys, left.output)
+ val hashed = buildHashRelation(rightIter)
+ val keyGenerator = streamedKeyGenerator()
leftIter.flatMap( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
+ leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey))
})
case RightOuter =>
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
- val keyGenerator = newProjection(rightKeys, right.output)
+ val hashed = buildHashRelation(leftIter)
+ val keyGenerator = streamedKeyGenerator()
rightIter.flatMap ( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+ rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow)
})
case FullOuter =>
+ // TODO(davies): use UnsafeRow
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
(leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 421d510e67..29f3beb3cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule */ @DeveloperApi case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { + + require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") + override def output: Seq[Attribute] = child.output override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false @@ -93,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { } case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, then convert everything - // to unsafe rows - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + // If this operator's children produce both unsafe and safe rows, + // convert everything unsafe rows if all the schema of them are support by UnsafeRow + if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } } } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 3854dc1b7a..d36e263937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -31,7 +31,7 @@ class UnsafeRowSuite extends SparkFunSuite { test("writeToStream") { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = - UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row) + UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) val bytesFromArrayBackedRow: Array[Byte] = { val baos = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9d9858b1c6..9dd2220f09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.{StructField, StructType, IntegerType} import org.apache.spark.util.collection.CompactBuffer @@ -35,13 +37,13 @@ 
class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) assert(hashed.get(InternalRow(10)) === null) val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) - assert(hashed.get(data(2)) == data2) + assert(hashed.get(data(2)) === data2) } test("UniqueKeyHashedRelation") { @@ -49,15 +51,40 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) == data(0)) - assert(uniqHashed.getValue(data(1)) == data(1)) - assert(uniqHashed.getValue(data(2)) == data(2)) - assert(uniqHashed.getValue(InternalRow(10)) == null) + assert(uniqHashed.getValue(data(0)) === data(0)) + assert(uniqHashed.getValue(data(1)) === data(1)) + assert(uniqHashed.getValue(data(2)) === data(2)) + assert(uniqHashed.getValue(InternalRow(10)) === null) + } + + test("UnsafeHashedRelation") { + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val buildKey = Seq(BoundReference(0, IntegerType, false)) + val schema = 
StructType(StructField("a", IntegerType, true) :: Nil) + val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1) + assert(hashed.isInstanceOf[UnsafeHashedRelation]) + + val toUnsafeKey = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafeKey(_).copy()).toArray + assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) + + val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) + data2 += unsafeData(2).copy() + assert(hashed.get(unsafeData(2)) === data2) + + val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) + .asInstanceOf[UnsafeHashedRelation] + assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed2.get(unsafeData(2)) === data2) } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 85cd02469a..61f483ced3 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -44,12 +44,16 @@ public final class Murmur3_x86_32 { return fmix(h1, 4); } - public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) { + public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. 
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; - for (int offset = 0; offset < lengthInBytes; offset += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } |