-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 50
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala | 82
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala | 74
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 85
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala | 49
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 10
20 files changed, 444 insertions(+), 135 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6ce03a48e9..7f08bf7b74 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions;
import java.io.IOException;
import java.io.OutputStream;
-import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;
@@ -354,7 +355,7 @@ public final class UnsafeRow extends MutableRow {
* This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
- public InternalRow copy() {
+ public UnsafeRow copy() {
if (pool != null) {
throw new UnsupportedOperationException(
"Copy is not supported for UnsafeRows that use object pools");
@@ -405,7 +406,50 @@ public final class UnsafeRow extends MutableRow {
}
@Override
+ public int hashCode() {
+ return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof UnsafeRow) {
+ UnsafeRow o = (UnsafeRow) other;
+ return (sizeInBytes == o.sizeInBytes) &&
+ ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
+ sizeInBytes);
+ }
+ return false;
+ }
+
+ /**
+ * Returns the underlying bytes for this UnsafeRow.
+ */
+ public byte[] getBytes() {
+ if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
+ && (((byte[]) baseObject).length == sizeInBytes)) {
+ return (byte[]) baseObject;
+ } else {
+ byte[] bytes = new byte[sizeInBytes];
+ PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
+ return bytes;
+ }
+ }
+
+ // This is for debugging
+ @Override
+ public String toString() {
+ StringBuilder build = new StringBuilder("[");
+ for (int i = 0; i < sizeInBytes; i += 8) {
+ build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+ build.append(',');
+ }
+ build.append(']');
+ return build.toString();
+ }
+
+ @Override
public boolean anyNull() {
- return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
+ return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
}
}
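
The new hashCode/equals pair is what lets UnsafeRow serve directly as a hash-map key later in this patch: equality is a byte-wise comparison of the row's backing memory, and the hash is Murmur3 over the same bytes with seed 42. A minimal sketch of that contract, assuming a Catalyst test classpath and using only APIs that appear in this diff (UnsafeProjection.create over an Array[DataType], copy(), equals, hashCode):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    object UnsafeRowKeySketch {
      def main(args: Array[String]): Unit = {
        val toUnsafe = UnsafeProjection.create(Array[DataType](StringType, IntegerType))
        // copy() detaches each row from any buffer the projection may reuse
        val a = toUnsafe(InternalRow(UTF8String.fromString("key"), 1)).copy()
        val b = toUnsafe(InternalRow(UTF8String.fromString("key"), 1)).copy()
        // Equal contents => identical backing bytes => equal rows and equal hashes,
        // which is the property UnsafeHashedRelation relies on below.
        assert(a == b && a.hashCode == b.hashCode)
      }
    }
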
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index d1d81c87bb..39fd6e1bc6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -28,11 +28,10 @@ import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -176,12 +175,7 @@ final class UnsafeExternalRowSorter {
*/
public static boolean supportsSchema(StructType schema) {
// TODO: add spilling note to explain why we do this for now:
- for (StructField field : schema.fields()) {
- if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
- return false;
- }
- }
- return true;
+ return UnsafeProjection.canSupport(schema);
}
private static final class RowComparator extends RecordComparator {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index b10a3c8774..4a13b687bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._
/**
@@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def toString: String = s"input[$ordinal, $dataType]"
- override def eval(input: InternalRow): Any = input(ordinal)
+ // Use special getter for primitive types (for UnsafeRow)
+ override def eval(input: InternalRow): Any = {
+ if (input.isNullAt(ordinal)) {
+ null
+ } else {
+ dataType match {
+ case BooleanType => input.getBoolean(ordinal)
+ case ByteType => input.getByte(ordinal)
+ case ShortType => input.getShort(ordinal)
+ case IntegerType | DateType => input.getInt(ordinal)
+ case LongType | TimestampType => input.getLong(ordinal)
+ case FloatType => input.getFloat(ordinal)
+ case DoubleType => input.getDouble(ordinal)
+ case _ => input.get(ordinal)
+ }
+ }
+ }
override def name: String = s"i[$ordinal]"
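
The rewritten eval dispatches on the declared data type so that primitive columns are read through the typed getters (getInt, getLong, ...) rather than the untyped get(ordinal) path, which is what makes evaluation work against the new UnsafeRow inputs. A hedged sketch of the observable behavior, assuming the same Catalyst classpath:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

    object BoundReferenceSketch {
      def main(args: Array[String]): Unit = {
        val row = UnsafeProjection.create(Array[DataType](IntegerType, LongType))
          .apply(InternalRow(7, 9L))
        // IntegerType routes through getInt and LongType through getLong;
        // eval still returns the values boxed as Any.
        assert(BoundReference(0, IntegerType, nullable = false).eval(row) == 7)
        assert(BoundReference(1, LongType, nullable = false).eval(row) == 9L)
      }
    }
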
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 24b01ea551..69758e653e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -83,12 +83,42 @@ abstract class UnsafeProjection extends Projection {
}
object UnsafeProjection {
+
+ /*
+   * Returns whether UnsafeProjection can support a given StructType, Array[DataType], or
+ * Seq[Expression].
+ */
+ def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType))
+ def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_))
+ def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray)
+
+ /**
+   * Returns an UnsafeProjection for a given StructType.
+ */
def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))
- def create(fields: Seq[DataType]): UnsafeProjection = {
+ /**
+   * Returns an UnsafeProjection for a given Array of DataTypes.
+ */
+ def create(fields: Array[DataType]): UnsafeProjection = {
val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
+ create(exprs)
+ }
+
+ /**
+   * Returns an UnsafeProjection for a given sequence of already-bound Expressions.
+ */
+ def create(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
}
+
+ /**
+   * Returns an UnsafeProjection for a given sequence of Expressions, which will be bound to
+ * `inputSchema`.
+ */
+ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
+ create(exprs.map(BindReferences.bindReference(_, inputSchema)))
+ }
}
/**
@@ -96,6 +126,8 @@ object UnsafeProjection {
*/
case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
+ def this(schema: StructType) = this(schema.fields.map(_.dataType))
+
private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
new BoundReference(idx, dt, true)
}
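
Taken together, the companion-object additions give operators a gate (canSupport) plus builders for both directions of the conversion. A hedged round-trip sketch, using only the overloads added here and the new FromUnsafeProjection auxiliary constructor:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    object ProjectionRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val schema = StructType(
          StructField("name", StringType) :: StructField("age", IntegerType) :: Nil)
        assert(UnsafeProjection.canSupport(schema))      // gate before converting
        val toUnsafe = UnsafeProjection.create(schema)   // safe -> unsafe
        val toSafe = new FromUnsafeProjection(schema)    // unsafe -> safe, new ctor
        val unsafe = toUnsafe(InternalRow(UTF8String.fromString("ada"), 36))
        assert(toSafe(unsafe).get(1) == 36)
      }
    }
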
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 7ffdce60d2..abaa4a6ce8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+ val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index ab757fc7de..c9d1a880f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.execution.joins
+import scala.concurrent._
+import scala.concurrent.duration._
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
-import scala.collection.JavaConversions._
-import scala.concurrent._
-import scala.concurrent.duration._
-
/**
* :: DeveloperApi ::
* Performs a outer hash join for two child relations. When the output RDD of this operator is
@@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
- private[this] lazy val (buildPlan, streamedPlan) = joinType match {
- case RightOuter => (left, right)
- case LeftOuter => (right, left)
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
-
- private[this] lazy val (buildKeys, streamedKeys) = joinType match {
- case RightOuter => (leftKeys, rightKeys)
- case LeftOuter => (rightKeys, leftKeys)
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
-
@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- // buildHashTable uses code-generated rows as keys, which are not serializable
- val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
+ val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
@@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin(
streamedPlan.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
- val keyGenerator = newProjection(streamedKeys, streamedPlan.output)
+ val keyGenerator = streamedKeyGenerator
joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
})
case RightOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
})
case x =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 2750f58b00..f71c0ce352 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -40,15 +40,14 @@ case class BroadcastLeftSemiJoinHash(
val buildIter = right.execute().map(_.copy()).collect().toIterator
if (condition.isEmpty) {
- // rowKey may be not serializable (from codegen)
- val hashSet = buildKeyHashSet(buildIter, copy = true)
+ val hashSet = buildKeyHashSet(buildIter)
val broadcastedRelation = sparkContext.broadcast(hashSet)
left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
- val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ val hashRelation = buildHashRelation(buildIter)
val broadcastedRelation = sparkContext.broadcast(hashRelation)
left.execute().mapPartitions { streamIter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 60b4266fad..700636966f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin(
case BuildLeft => (right, left)
}
+ override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+
+ @transient private[this] lazy val resultProjection: Projection = {
+ if (outputsUnsafeRows) {
+ UnsafeProjection.create(schema)
+ } else {
+ new Projection {
+ override def apply(r: InternalRow): InternalRow = r
+ }
+ }
+ }
+
override def outputPartitioning: Partitioning = streamed.outputPartitioning
override def output: Seq[Attribute] = {
@@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin(
val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow
+
val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
@@ -86,11 +100,11 @@ case class BroadcastNestedLoopJoin(
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
- matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+ matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
- matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+ matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case _ =>
@@ -100,9 +114,9 @@ case class BroadcastNestedLoopJoin(
(streamRowMatched, joinType, buildSide) match {
case (false, LeftOuter | FullOuter, BuildRight) =>
- matchedRows += joinedRow(streamedRow, rightNulls).copy()
+ matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
case (false, RightOuter | FullOuter, BuildLeft) =>
- matchedRows += joinedRow(leftNulls, streamedRow).copy()
+ matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
case _ =>
}
}
@@ -110,12 +124,9 @@ case class BroadcastNestedLoopJoin(
}
val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
- val allIncludedBroadcastTuples =
- if (includedBroadcastTuples.count == 0) {
- new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
- } else {
- includedBroadcastTuples.reduce(_ ++ _)
- }
+ val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ )(_ ++ _)
val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
@@ -127,8 +138,10 @@ case class BroadcastNestedLoopJoin(
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
- case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
+ case (RightOuter | FullOuter, BuildRight) =>
+ buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
+ case (LeftOuter | FullOuter, BuildLeft) =>
+ buf += resultProjection(new JoinedRow(rel(i), rightNulls))
case _ =>
}
}
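
The fold rewrite above replaces the explicit count()-based empty check by giving the reduction a neutral element. A small, collections-only sketch of why the two are equivalent (the real code folds an RDD of BitSets, which has the same algebra):

    import scala.collection.mutable.BitSet

    object BitSetFoldSketch {
      def main(args: Array[String]): Unit = {
        val perPartition: Seq[BitSet] = Seq(BitSet(0, 3), BitSet(2))
        // Folding with an empty BitSet as the zero value merges all partitions...
        assert(perPartition.fold(new BitSet(8))(_ ++ _) == BitSet(0, 2, 3))
        // ...and the "no matches at all" case falls out without a special branch.
        assert(Seq.empty[BitSet].fold(new BitSet(8))(_ ++ _).isEmpty)
      }
    }
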
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index ff85ea3f6a..ae34409bcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -44,11 +44,20 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
- @transient protected lazy val buildSideKeyGenerator: Projection =
- newProjection(buildKeys, buildPlan.output)
+ protected[this] def supportUnsafe: Boolean = {
+ (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
+ && UnsafeProjection.canSupport(self.schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = supportUnsafe
+ override def canProcessUnsafeRows: Boolean = supportUnsafe
- @transient protected lazy val streamSideKeyGenerator: () => MutableProjection =
- newMutableProjection(streamedKeys, streamedPlan.output)
+ @transient protected lazy val streamSideKeyGenerator: Projection =
+ if (supportUnsafe) {
+ UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ } else {
+ newMutableProjection(streamedKeys, streamedPlan.output)()
+ }
protected def hashJoin(
streamIter: Iterator[InternalRow],
@@ -61,8 +70,17 @@ trait HashJoin {
// Mutable per row objects.
private[this] val joinRow = new JoinedRow2
+ private[this] val resultProjection: Projection = {
+ if (supportUnsafe) {
+ UnsafeProjection.create(self.schema)
+ } else {
+ new Projection {
+ override def apply(r: InternalRow): InternalRow = r
+ }
+ }
+ }
- private[this] val joinKeys = streamSideKeyGenerator()
+ private[this] val joinKeys = streamSideKeyGenerator
override final def hasNext: Boolean =
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
@@ -74,7 +92,7 @@ trait HashJoin {
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
}
currentMatchPosition += 1
- ret
+ resultProjection(ret)
}
/**
@@ -89,8 +107,9 @@ trait HashJoin {
while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
- if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatches = hashedRelation.get(joinKeys.currentValue)
+ val key = joinKeys(currentStreamedRow)
+ if (!key.anyNull) {
+ currentHashMatches = hashedRelation.get(key)
}
}
@@ -103,4 +122,12 @@ trait HashJoin {
}
}
}
+
+ protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+ if (supportUnsafe) {
+ UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
+ } else {
+ HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 74a7db7761..6bf2f82954 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer
@@ -38,7 +38,7 @@ trait HashOuterJoin {
val left: SparkPlan
val right: SparkPlan
-override def outputPartitioning: Partitioning = joinType match {
+ override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -59,6 +59,49 @@ override def outputPartitioning: Partitioning = joinType match {
}
}
+ protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
+ case RightOuter => (left, right)
+ case LeftOuter => (right, left)
+ case x =>
+ throw new IllegalArgumentException(
+ s"HashOuterJoin should not take $x as the JoinType")
+ }
+
+ protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
+ case RightOuter => (leftKeys, rightKeys)
+ case LeftOuter => (rightKeys, leftKeys)
+ case x =>
+ throw new IllegalArgumentException(
+ s"HashOuterJoin should not take $x as the JoinType")
+ }
+
+ protected[this] def supportUnsafe: Boolean = {
+ (self.codegenEnabled && joinType != FullOuter
+ && UnsafeProjection.canSupport(buildKeys)
+ && UnsafeProjection.canSupport(self.schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = supportUnsafe
+ override def canProcessUnsafeRows: Boolean = supportUnsafe
+
+ protected[this] def streamedKeyGenerator(): Projection = {
+ if (supportUnsafe) {
+ UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ } else {
+ newProjection(streamedKeys, streamedPlan.output)
+ }
+ }
+
+ @transient private[this] lazy val resultProjection: Projection = {
+ if (supportUnsafe) {
+ UnsafeProjection.create(self.schema)
+ } else {
+ new Projection {
+ override def apply(r: InternalRow): InternalRow = r
+ }
+ }
+ }
+
@transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
@transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
@@ -76,16 +119,20 @@ override def outputPartitioning: Partitioning = joinType match {
rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
- val temp = rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
+ val temp = if (rightIter != null) {
+ rightIter.collect {
+ case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy()
+ }
+ } else {
+ List.empty
}
if (temp.isEmpty) {
- joinedRow.withRight(rightNullRow).copy :: Nil
+ resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
} else {
temp
}
} else {
- joinedRow.withRight(rightNullRow).copy :: Nil
+ resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
}
}
ret.iterator
@@ -97,17 +144,21 @@ override def outputPartitioning: Partitioning = joinType match {
joinedRow: JoinedRow): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
- val temp = leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) =>
- joinedRow.copy()
+ val temp = if (leftIter != null) {
+ leftIter.collect {
+ case l if boundCondition(joinedRow.withLeft(l)) =>
+ resultProjection(joinedRow).copy()
+ }
+ } else {
+ List.empty
}
if (temp.isEmpty) {
- joinedRow.withLeft(leftNullRow).copy :: Nil
+ resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
} else {
temp
}
} else {
- joinedRow.withLeft(leftNullRow).copy :: Nil
+ resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
}
}
ret.iterator
@@ -159,6 +210,7 @@ override def outputPartitioning: Partitioning = joinType match {
}
}
+ // This is only used by FullOuter
protected[this] def buildHashTable(
iter: Iterator[InternalRow],
keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
@@ -178,4 +230,12 @@ override def outputPartitioning: Partitioning = joinType match {
hashTable
}
+
+ protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+ if (supportUnsafe) {
+ UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
+ } else {
+ HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 1b983bc3a9..7f49264d40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -32,34 +32,45 @@ trait HashSemiJoin {
override def output: Seq[Attribute] = left.output
- @transient protected lazy val rightKeyGenerator: Projection =
- newProjection(rightKeys, right.output)
+ protected[this] def supportUnsafe: Boolean = {
+ (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(rightKeys)
+ && UnsafeProjection.canSupport(left.schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = supportUnsafe
+
+ @transient protected lazy val leftKeyGenerator: Projection =
+ if (supportUnsafe) {
+ UnsafeProjection.create(leftKeys, left.output)
+ } else {
+ newMutableProjection(leftKeys, left.output)()
+ }
- @transient protected lazy val leftKeyGenerator: () => MutableProjection =
- newMutableProjection(leftKeys, left.output)
+ @transient protected lazy val rightKeyGenerator: Projection =
+ if (supportUnsafe) {
+ UnsafeProjection.create(rightKeys, right.output)
+ } else {
+ newMutableProjection(rightKeys, right.output)()
+ }
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- protected def buildKeyHashSet(
- buildIter: Iterator[InternalRow],
- copy: Boolean): java.util.Set[InternalRow] = {
+ protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null
// Create a Hash set of buildKeys
+ val rightKey = rightKeyGenerator
while (buildIter.hasNext) {
currentRow = buildIter.next()
- val rowKey = rightKeyGenerator(currentRow)
+ val rowKey = rightKey(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
- if (copy) {
- hashSet.add(rowKey.copy())
- } else {
- // rowKey may be not serializable (from codegen)
- hashSet.add(rowKey)
- }
+ hashSet.add(rowKey.copy())
}
}
}
@@ -67,25 +78,34 @@ trait HashSemiJoin {
}
protected def hashSemiJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation): Iterator[InternalRow] = {
- val joinKeys = leftKeyGenerator()
- val joinedRow = new JoinedRow
+ streamIter: Iterator[InternalRow],
+ hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
+ val joinKeys = leftKeyGenerator
streamIter.filter(current => {
- lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
- !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
- (build: InternalRow) => boundCondition(joinedRow(current, build))
- }
+ val key = joinKeys(current)
+ !key.anyNull && hashSet.contains(key)
})
}
+ protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+ if (supportUnsafe) {
+ UnsafeHashedRelation(buildIter, rightKeys, right)
+ } else {
+ HashedRelation(buildIter, newProjection(rightKeys, right.output))
+ }
+ }
+
protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
- hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
- val joinKeys = leftKeyGenerator()
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = leftKeyGenerator
val joinedRow = new JoinedRow
- streamIter.filter(current => {
- !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue)
- })
+ streamIter.filter { current =>
+ val key = joinKeys(current)
+ lazy val rowBuffer = hashedRelation.get(key)
+ !key.anyNull && rowBuffer != null && rowBuffer.exists {
+ (row: InternalRow) => boundCondition(joinedRow(current, row))
+ }
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 6b51f5d415..8d5731afd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.execution.joins
-import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.util.{HashMap => JavaHashMap}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.CompactBuffer
@@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
}
}
-
// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
@@ -148,3 +148,80 @@ private[joins] object HashedRelation {
}
}
}
+
+
+/**
+ * A HashedRelation for UnsafeRow, currently backed by a java.util.HashMap that maps each key
+ * to a sequence of matching rows.
+ *
+ * TODO(davies): use BytesToBytesMap
+ */
+private[joins] final class UnsafeHashedRelation(
+ private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
+ extends HashedRelation with Externalizable {
+
+ def this() = this(null) // Needed for serialization
+
+ override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+    // The asInstanceOf below is safe thanks to type erasure
+ hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ }
+}
+
+private[joins] object UnsafeHashedRelation {
+
+ def apply(
+ input: Iterator[InternalRow],
+ buildKeys: Seq[Expression],
+ buildPlan: SparkPlan,
+ sizeEstimate: Int = 64): HashedRelation = {
+ val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output))
+ apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
+ }
+
+ // Used for tests
+ def apply(
+ input: Iterator[InternalRow],
+ buildKeys: Seq[Expression],
+ rowSchema: StructType,
+ sizeEstimate: Int): HashedRelation = {
+
+ // TODO: Use BytesToBytesMap.
+ val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val toUnsafe = UnsafeProjection.create(rowSchema)
+ val keyGenerator = UnsafeProjection.create(buildKeys)
+
+ // Create a mapping of buildKeys -> rows
+ while (input.hasNext) {
+ val currentRow = input.next()
+ val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
+ currentRow.asInstanceOf[UnsafeRow]
+ } else {
+ toUnsafe(currentRow)
+ }
+ val rowKey = keyGenerator(unsafeRow)
+ if (!rowKey.anyNull) {
+ val existingMatchList = hashTable.get(rowKey)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[UnsafeRow]()
+ hashTable.put(rowKey.copy(), newMatchList)
+ newMatchList
+ } else {
+ existingMatchList
+ }
+ matchList += unsafeRow.copy()
+ }
+ }
+
+ new UnsafeHashedRelation(hashTable)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index db5be9f453..4443455ef1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -39,6 +39,9 @@ case class LeftSemiJoinBNL(
override def output: Seq[Attribute] = left.output
+ override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+
/** The Streamed Relation */
override def left: SparkPlan = streamed
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 9eaac817d9..874712a4e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -43,10 +43,10 @@ case class LeftSemiJoinHash(
protected override def doExecute(): RDD[InternalRow] = {
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(buildIter, copy = false)
+ val hashSet = buildKeyHashSet(buildIter)
hashSemiJoin(streamIter, hashSet)
} else {
- val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ val hashRelation = buildHashRelation(buildIter)
hashSemiJoin(streamIter, hashRelation)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5439e10a60..948d0ccebc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -45,7 +45,7 @@ case class ShuffledHashJoin(
protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
+ val hashed = buildHashRelation(buildIter)
hashJoin(streamIter, hashed)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index ab0a6ad56a..f54f1edd38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin(
// TODO this probably can be replaced by external sort (sort merged join?)
joinType match {
case LeftOuter =>
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
- val keyGenerator = newProjection(leftKeys, left.output)
+ val hashed = buildHashRelation(rightIter)
+ val keyGenerator = streamedKeyGenerator()
leftIter.flatMap( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
+ leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey))
})
case RightOuter =>
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
- val keyGenerator = newProjection(rightKeys, right.output)
+ val hashed = buildHashRelation(leftIter)
+ val keyGenerator = streamedKeyGenerator()
rightIter.flatMap ( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+ rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow)
})
case FullOuter =>
+ // TODO(davies): use UnsafeRow
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
(leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
index 421d510e67..29f3beb3cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
@DeveloperApi
case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
+
+ require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe")
+
override def output: Seq[Attribute] = child.output
override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = false
@@ -93,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
}
case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
- // If this operator's children produce both unsafe and safe rows, then convert everything
- // to unsafe rows
- operator.withNewChildren {
- operator.children.map {
- c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+ // If this operator's children produce both unsafe and safe rows,
+          // convert everything to unsafe rows if all of the children's schemas are supported by UnsafeRow
+ if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) {
+ operator.withNewChildren {
+ operator.children.map {
+ c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+ }
+ }
+ } else {
+ operator.withNewChildren {
+ operator.children.map {
+ c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
+ }
}
}
} else {
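
With this change EnsureRowFormats has two fallbacks when children disagree on row format. A hedged, self-contained restatement of the decision using a stand-in Plan case class (not Spark's SparkPlan):

    case class Plan(outputsUnsafeRows: Boolean, schemaSupportsUnsafe: Boolean)

    object EnsureFormatsSketch {
      // Mirrors the branch above: convert everything to unsafe only if every
      // child schema can be represented as UnsafeRow, otherwise go safe.
      def unify(children: Seq[Plan]): Seq[Plan] =
        if (children.map(_.outputsUnsafeRows).toSet.size == 1) {
          children
        } else if (children.forall(_.schemaSupportsUnsafe)) {
          children.map(_.copy(outputsUnsafeRows = true))   // ConvertToUnsafe
        } else {
          children.map(_.copy(outputsUnsafeRows = false))  // ConvertToSafe
        }

      def main(args: Array[String]): Unit = {
        val mixed = Seq(Plan(outputsUnsafeRows = true, schemaSupportsUnsafe = true),
                        Plan(outputsUnsafeRows = false, schemaSupportsUnsafe = false))
        assert(unify(mixed).forall(p => !p.outputsUnsafeRows))
      }
    }
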
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 3854dc1b7a..d36e263937 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.memory.MemoryAllocator
import org.apache.spark.unsafe.types.UTF8String
@@ -31,7 +31,7 @@ class UnsafeRowSuite extends SparkFunSuite {
test("writeToStream") {
val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
val arrayBackedUnsafeRow: UnsafeRow =
- UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)
+ UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row)
assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]])
val bytesFromArrayBackedRow: Array[Byte] = {
val baos = new ByteArrayOutputStream()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 9d9858b1c6..9dd2220f09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
import org.apache.spark.util.collection.CompactBuffer
@@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite {
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])
- assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
- assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
+ assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+ assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(InternalRow(10)) === null)
val data2 = CompactBuffer[InternalRow](data(2))
data2 += data(2)
- assert(hashed.get(data(2)) == data2)
+ assert(hashed.get(data(2)) === data2)
}
test("UniqueKeyHashedRelation") {
@@ -49,15 +51,40 @@ class HashedRelationSuite extends SparkFunSuite {
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
- assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
- assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
- assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2)))
+ assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+ assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
+ assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
assert(hashed.get(InternalRow(10)) === null)
val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
- assert(uniqHashed.getValue(data(0)) == data(0))
- assert(uniqHashed.getValue(data(1)) == data(1))
- assert(uniqHashed.getValue(data(2)) == data(2))
- assert(uniqHashed.getValue(InternalRow(10)) == null)
+ assert(uniqHashed.getValue(data(0)) === data(0))
+ assert(uniqHashed.getValue(data(1)) === data(1))
+ assert(uniqHashed.getValue(data(2)) === data(2))
+ assert(uniqHashed.getValue(InternalRow(10)) === null)
+ }
+
+ test("UnsafeHashedRelation") {
+ val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
+ val buildKey = Seq(BoundReference(0, IntegerType, false))
+ val schema = StructType(StructField("a", IntegerType, true) :: Nil)
+ val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1)
+ assert(hashed.isInstanceOf[UnsafeHashedRelation])
+
+ val toUnsafeKey = UnsafeProjection.create(schema)
+ val unsafeData = data.map(toUnsafeKey(_).copy()).toArray
+ assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
+ assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed.get(toUnsafeKey(InternalRow(10))) === null)
+
+ val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
+ data2 += unsafeData(2).copy()
+ assert(hashed.get(unsafeData(2)) === data2)
+
+ val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
+ .asInstanceOf[UnsafeHashedRelation]
+ assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
+ assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
+ assert(hashed2.get(unsafeData(2)) === data2)
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
index 85cd02469a..61f483ced3 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -44,12 +44,16 @@ public final class Murmur3_x86_32 {
return fmix(h1, 4);
}
- public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
+ public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+ return hashUnsafeWords(base, offset, lengthInBytes, seed);
+ }
+
+ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
// This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
int h1 = seed;
- for (int offset = 0; offset < lengthInBytes; offset += 4) {
- int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+ for (int i = 0; i < lengthInBytes; i += 4) {
+ int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i);
int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
}
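
The new static hashUnsafeWords overload is what UnsafeRow.hashCode calls with seed 42. A hedged sketch of hashing an 8-byte-aligned byte[] region directly, using the same PlatformDependent base offset the rest of the patch uses:

    import org.apache.spark.unsafe.PlatformDependent
    import org.apache.spark.unsafe.hash.Murmur3_x86_32

    object MurmurWordsSketch {
      def main(args: Array[String]): Unit = {
        val bytes = new Array[Byte](16)   // length must be a multiple of 8 (word-aligned)
        bytes(0) = 1
        val hash = Murmur3_x86_32.hashUnsafeWords(
          bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length, 42)
        println(s"murmur3_x86_32(seed = 42) = $hash")
      }
    }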