author    Davies Liu <davies@databricks.com>    2015-07-22 13:02:43 -0700
committer Davies Liu <davies.liu@gmail.com>    2015-07-22 13:02:43 -0700
commit    e0b7ba59a1ace9b78a1ad6f3f07fe153db20b52c (patch)
tree      539e14cbb49b30181461e7e01ca0056a5f1fe935
parent    86f80e2b4759e574fe3eb91695f81b644db87242 (diff)
[SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin
This PR introduces unsafe versions (using UnsafeRow) of HashJoin, HashOuterJoin and HashSemiJoin, covering both the broadcast and shuffled variants (except FullOuterJoin, which is better implemented with SortMergeJoin). For now it uses a HashMap to store UnsafeRows; a follow-up PR will switch to BytesToBytesMap for better performance.

Author: Davies Liu <davies@databricks.com>

Closes #7480 from davies/unsafe_join and squashes the following commits:

6294b1e [Davies Liu] fix projection
10583f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
dede020 [Davies Liu] fix test
84c9807 [Davies Liu] address comments
a05b4f6 [Davies Liu] support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin
611d2ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
9481ae8 [Davies Liu] return UnsafeRow after join()
ca2b40f [Davies Liu] revert unrelated change
68f5cd9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
0f4380d [Davies Liu] ada a comment
69e38f5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
1a40f02 [Davies Liu] refactor
ab1690f [Davies Liu] address comments
60371f2 [Davies Liu] use UnsafeRow in SemiJoin
a6c0b7d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
184b852 [Davies Liu] fix style
6acbb11 [Davies Liu] fix tests
95d0762 [Davies Liu] remove println
bea4a50 [Davies Liu] Unsafe HashJoin
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 50
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala | 82
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala | 74
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 85
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala | 49
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 10
20 files changed, 444 insertions, 135 deletions
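
The core of the change is that UnsafeRow now implements byte-based equals() and hashCode(), so it can be used directly as a key in an ordinary hash map. For orientation, here is a minimal sketch of the build side, mirroring UnsafeHashedRelation.apply in the HashedRelation.scala diff below; the helper name buildUnsafeRelation is hypothetical, and it assumes Spark's catalyst internals (as of this commit) are on the classpath:

  import java.util.{HashMap => JavaHashMap}
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow}
  import org.apache.spark.sql.types.StructType
  import org.apache.spark.util.collection.CompactBuffer

  // Simplified build loop: keys and values are both UnsafeRows.
  def buildUnsafeRelation(
      input: Iterator[InternalRow],
      buildKeys: Seq[Expression],   // already bound to the build plan's output
      rowSchema: StructType): JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]] = {
    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]()
    val toUnsafe = UnsafeProjection.create(rowSchema)
    val keyGenerator = UnsafeProjection.create(buildKeys)
    while (input.hasNext) {
      val row = input.next()
      val unsafeRow = row match {
        case u: UnsafeRow => u
        case other => toUnsafe(other)
      }
      val key = keyGenerator(unsafeRow)
      if (!key.anyNull) {                        // rows with null keys never match
        var matches = hashTable.get(key)
        if (matches == null) {
          matches = new CompactBuffer[UnsafeRow]()
          hashTable.put(key.copy(), matches)     // copy: projections reuse their buffer
        }
        matches += unsafeRow.copy()
      }
    }
    hashTable
  }

A plain JavaHashMap is a deliberate stopgap here; per the TODO in the diff, a later PR replaces it with BytesToBytesMap.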
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6ce03a48e9..7f08bf7b74 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions;
import java.io.IOException;
import java.io.OutputStream;
-import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;
@@ -354,7 +355,7 @@ public final class UnsafeRow extends MutableRow {
* This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
- public InternalRow copy() {
+ public UnsafeRow copy() {
if (pool != null) {
throw new UnsupportedOperationException(
"Copy is not supported for UnsafeRows that use object pools");
@@ -405,7 +406,50 @@ public final class UnsafeRow extends MutableRow {
}
@Override
+ public int hashCode() {
+ return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof UnsafeRow) {
+ UnsafeRow o = (UnsafeRow) other;
+ return (sizeInBytes == o.sizeInBytes) &&
+ ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
+ sizeInBytes);
+ }
+ return false;
+ }
+
+ /**
+ * Returns the underlying bytes for this UnsafeRow.
+ */
+ public byte[] getBytes() {
+ if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
+ && (((byte[]) baseObject).length == sizeInBytes)) {
+ return (byte[]) baseObject;
+ } else {
+ byte[] bytes = new byte[sizeInBytes];
+ PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
+ return bytes;
+ }
+ }
+
+ // This is for debugging
+ @Override
+ public String toString() {
+ StringBuilder build = new StringBuilder("[");
+ for (int i = 0; i < sizeInBytes; i += 8) {
+ build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+ build.append(',');
+ }
+ build.append(']');
+ return build.toString();
+ }
+
+ @Override
public boolean anyNull() {
- return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
+ return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
}
}
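
In practice, the new equals()/hashCode() mean that two UnsafeRows with identical contents (and no object pool) compare equal and hash to the same value, which is what lets them serve as hash-map keys in the join code below. A small sketch under those assumptions, using only APIs that appear elsewhere in this patch:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
  import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
  import org.apache.spark.unsafe.types.UTF8String

  val toUnsafe = UnsafeProjection.create(Array[DataType](StringType, IntegerType))
  // Two structurally identical rows, copied out of the projection's reused buffer...
  val a = toUnsafe(InternalRow(UTF8String.fromString("spark"), 1)).copy()
  val b = toUnsafe(InternalRow(UTF8String.fromString("spark"), 1)).copy()
  // ...are equal byte-for-byte and produce the same Murmur3 hash.
  assert(a == b && a.hashCode == b.hashCode)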
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index d1d81c87bb..39fd6e1bc6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -28,11 +28,10 @@ import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -176,12 +175,7 @@ final class UnsafeExternalRowSorter {
*/
public static boolean supportsSchema(StructType schema) {
// TODO: add spilling note to explain why we do this for now:
- for (StructField field : schema.fields()) {
- if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
- return false;
- }
- }
- return true;
+ return UnsafeProjection.canSupport(schema);
}
private static final class RowComparator extends RecordComparator {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index b10a3c8774..4a13b687bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._
/**
@@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def toString: String = s"input[$ordinal, $dataType]"
- override def eval(input: InternalRow): Any = input(ordinal)
+ // Use special getter for primitive types (for UnsafeRow)
+ override def eval(input: InternalRow): Any = {
+ if (input.isNullAt(ordinal)) {
+ null
+ } else {
+ dataType match {
+ case BooleanType => input.getBoolean(ordinal)
+ case ByteType => input.getByte(ordinal)
+ case ShortType => input.getShort(ordinal)
+ case IntegerType | DateType => input.getInt(ordinal)
+ case LongType | TimestampType => input.getLong(ordinal)
+ case FloatType => input.getFloat(ordinal)
+ case DoubleType => input.getDouble(ordinal)
+ case _ => input.get(ordinal)
+ }
+ }
+ }
override def name: String = s"i[$ordinal]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 24b01ea551..69758e653e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -83,12 +83,42 @@ abstract class UnsafeProjection extends Projection {
}
object UnsafeProjection {
+
+ /*
+ * Returns whether UnsafeProjection can support given StructType, Array[DataType] or
+ * Seq[Expression].
+ */
+ def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType))
+ def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_))
+ def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray)
+
+ /**
+ * Returns an UnsafeProjection for given StructType.
+ */
def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))
- def create(fields: Seq[DataType]): UnsafeProjection = {
+ /**
+ * Returns an UnsafeProjection for given Array of DataTypes.
+ */
+ def create(fields: Array[DataType]): UnsafeProjection = {
val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
+ create(exprs)
+ }
+
+ /**
+ * Returns an UnsafeProjection for given sequence of Expressions (bounded).
+ */
+ def create(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
}
+
+ /**
+ * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to
+ * `inputSchema`.
+ */
+ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
+ create(exprs.map(BindReferences.bindReference(_, inputSchema)))
+ }
}
/**
@@ -96,6 +126,8 @@ object UnsafeProjection {
*/
case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
+ def this(schema: StructType) = this(schema.fields.map(_.dataType))
+
private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
new BoundReference(idx, dt, true)
}
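
For illustration, a hedged usage sketch of the new UnsafeProjection factory methods (the schema below is made up): canSupport lets callers check up front whether every column type can be embedded in an UnsafeRow before asking for a projection.

  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

  val schema = StructType(
    StructField("id", IntegerType, nullable = false) ::
    StructField("name", StringType, nullable = true) :: Nil)

  // Only build an unsafe projection when every column type can be embedded.
  if (UnsafeProjection.canSupport(schema)) {
    val toUnsafe = UnsafeProjection.create(schema)  // same as create(schema.fields.map(_.dataType))
    // toUnsafe(internalRow) then yields an UnsafeRow backed by a single byte region
  }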
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 7ffdce60d2..abaa4a6ce8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+ val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index ab757fc7de..c9d1a880f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.execution.joins
+import scala.concurrent._
+import scala.concurrent.duration._
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
-import scala.collection.JavaConversions._
-import scala.concurrent._
-import scala.concurrent.duration._
-
/**
* :: DeveloperApi ::
* Performs a outer hash join for two child relations. When the output RDD of this operator is
@@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
- private[this] lazy val (buildPlan, streamedPlan) = joinType match {
- case RightOuter => (left, right)
- case LeftOuter => (right, left)
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
-
- private[this] lazy val (buildKeys, streamedKeys) = joinType match {
- case RightOuter => (leftKeys, rightKeys)
- case LeftOuter => (rightKeys, leftKeys)
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
-
@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- // buildHashTable uses code-generated rows as keys, which are not serializable
- val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
+ val hashed = buildHashRelation(input.iterator)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
@@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin(
streamedPlan.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
- val keyGenerator = newProjection(streamedKeys, streamedPlan.output)
+ val keyGenerator = streamedKeyGenerator
joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
})
case RightOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
})
case x =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 2750f58b00..f71c0ce352 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -40,15 +40,14 @@ case class BroadcastLeftSemiJoinHash(
val buildIter = right.execute().map(_.copy()).collect().toIterator
if (condition.isEmpty) {
- // rowKey may be not serializable (from codegen)
- val hashSet = buildKeyHashSet(buildIter, copy = true)
+ val hashSet = buildKeyHashSet(buildIter)
val broadcastedRelation = sparkContext.broadcast(hashSet)
left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
- val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ val hashRelation = buildHashRelation(buildIter)
val broadcastedRelation = sparkContext.broadcast(hashRelation)
left.execute().mapPartitions { streamIter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 60b4266fad..700636966f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin(
case BuildLeft => (right, left)
}
+ override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+
+ @transient private[this] lazy val resultProjection: Projection = {
+ if (outputsUnsafeRows) {
+ UnsafeProjection.create(schema)
+ } else {
+ new Projection {
+ override def apply(r: InternalRow): InternalRow = r
+ }
+ }
+ }
+
override def outputPartitioning: Partitioning = streamed.outputPartitioning
override def output: Seq[Attribute] = {
@@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin(
val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow
+
val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
@@ -86,11 +100,11 @@ case class BroadcastNestedLoopJoin(
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
- matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+ matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
- matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+ matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case _ =>
@@ -100,9 +114,9 @@ case class BroadcastNestedLoopJoin(
(streamRowMatched, joinType, buildSide) match {
case (false, LeftOuter | FullOuter, BuildRight) =>
- matchedRows += joinedRow(streamedRow, rightNulls).copy()
+ matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
case (false, RightOuter | FullOuter, BuildLeft) =>
- matchedRows += joinedRow(leftNulls, streamedRow).copy()
+ matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
case _ =>
}
}
@@ -110,12 +124,9 @@ case class BroadcastNestedLoopJoin(
}
val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
- val allIncludedBroadcastTuples =
- if (includedBroadcastTuples.count == 0) {
- new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
- } else {
- includedBroadcastTuples.reduce(_ ++ _)
- }
+ val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ )(_ ++ _)
val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
@@ -127,8 +138,10 @@ case class BroadcastNestedLoopJoin(
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
- case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
+ case (RightOuter | FullOuter, BuildRight) =>
+ buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
+ case (LeftOuter | FullOuter, BuildLeft) =>
+ buf += resultProjection(new JoinedRow(rel(i), rightNulls))
case _ =>
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index ff85ea3f6a..ae34409bcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -44,11 +44,20 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
- @transient protected lazy val buildSideKeyGenerator: Projection =
- newProjection(buildKeys, buildPlan.output)
+ protected[this] def supportUnsafe: Boolean = {
+ (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
+ && UnsafeProjection.canSupport(self.schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = supportUnsafe
+ override def canProcessUnsafeRows: Boolean = supportUnsafe
- @transient protected lazy val streamSideKeyGenerator: () => MutableProjection =
- newMutableProjection(streamedKeys, streamedPlan.output)
+ @transient protected lazy val streamSideKeyGenerator: Projection =
+ if (supportUnsafe) {
+ UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ } else {
+ newMutableProjection(streamedKeys, streamedPlan.output)()
+ }
protected def hashJoin(
streamIter: Iterator[InternalRow],
@@ -61,8 +70,17 @@ trait HashJoin {
// Mutable per row objects.
private[this] val joinRow = new JoinedRow2
+ private[this] val resultProjection: Projection = {
+ if (supportUnsafe) {
+ UnsafeProjection.create(self.schema)
+ } else {
+ new Projection {
+ override def apply(r: InternalRow): InternalRow = r
+ }
+ }
+ }
- private[this] val joinKeys = streamSideKeyGenerator()
+ private[this] val joinKeys = streamSideKeyGenerator
override final def hasNext: Boolean =
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
@@ -74,7 +92,7 @@ trait HashJoin {
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
}
currentMatchPosition += 1
- ret
+ resultProjection(ret)
}
/**
@@ -89,8 +107,9 @@ trait HashJoin {
while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
- if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatches = hashedRelation.get(joinKeys.currentValue)
+ val key = joinKeys(currentStreamedRow)
+ if (!key.anyNull) {
+ currentHashMatches = hashedRelation.get(key)
}
}
@@ -103,4 +122,12 @@ trait HashJoin {
}
}
}
+
+ protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+ if (supportUnsafe) {
+ UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
+ } else {
+ HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 74a7db7761..6bf2f82954 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer
@@ -38,7 +38,7 @@ trait HashOuterJoin {
val left: SparkPlan
val right: SparkPlan
-override def outputPartitioning: Partitioning = joinType match {
+ override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -59,6 +59,49 @@ override def outputPartitioning: Partitioning = joinType match {
}
}
+ protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
+ case RightOuter => (left, right)
+ case LeftOuter => (right, left)
+ case x =>
+ throw new IllegalArgumentException(
+ s"HashOuterJoin should not take $x as the JoinType")
+ }
+
+ protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
+ case RightOuter => (leftKeys, rightKeys)
+ case LeftOuter => (rightKeys, leftKeys)
+ case x =>
+ throw new IllegalArgumentException(
+ s"HashOuterJoin should not take $x as the JoinType")
+ }
+
+ protected[this] def supportUnsafe: Boolean = {
+ (self.codegenEnabled && joinType != FullOuter
+ && UnsafeProjection.canSupport(buildKeys)
+ && UnsafeProjection.canSupport(self.schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = supportUnsafe
+ override def canProcessUnsafeRows: Boolean = supportUnsafe
+
+ protected[this] def streamedKeyGenerator(): Projection = {
+ if (supportUnsafe) {
+ UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ } else {
+ newProjection(streamedKeys, streamedPlan.output)
+ }
+ }
+
+ @transient private[this] lazy val resultProjection: Projection = {
+ if (supportUnsafe) {
+ UnsafeProjection.create(self.schema)
+ } else {
+ new Projection {
+ override def apply(r: InternalRow): InternalRow = r
+ }
+ }
+ }
+
@transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
@transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
@@ -76,16 +119,20 @@ override def outputPartitioning: Partitioning = joinType match {
rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
- val temp = rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
+ val temp = if (rightIter != null) {
+ rightIter.collect {
+ case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy()
+ }
+ } else {
+ List.empty
}
if (temp.isEmpty) {
- joinedRow.withRight(rightNullRow).copy :: Nil
+ resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
} else {
temp
}
} else {
- joinedRow.withRight(rightNullRow).copy :: Nil
+ resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
}
}
ret.iterator
@@ -97,17 +144,21 @@ override def outputPartitioning: Partitioning = joinType match {
joinedRow: JoinedRow): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
- val temp = leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) =>
- joinedRow.copy()
+ val temp = if (leftIter != null) {
+ leftIter.collect {
+ case l if boundCondition(joinedRow.withLeft(l)) =>
+ resultProjection(joinedRow).copy()
+ }
+ } else {
+ List.empty
}
if (temp.isEmpty) {
- joinedRow.withLeft(leftNullRow).copy :: Nil
+ resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
} else {
temp
}
} else {
- joinedRow.withLeft(leftNullRow).copy :: Nil
+ resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
}
}
ret.iterator
@@ -159,6 +210,7 @@ override def outputPartitioning: Partitioning = joinType match {
}
}
+ // This is only used by FullOuter
protected[this] def buildHashTable(
iter: Iterator[InternalRow],
keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
@@ -178,4 +230,12 @@ override def outputPartitioning: Partitioning = joinType match {
hashTable
}
+
+ protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+ if (supportUnsafe) {
+ UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
+ } else {
+ HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 1b983bc3a9..7f49264d40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -32,34 +32,45 @@ trait HashSemiJoin {
override def output: Seq[Attribute] = left.output
- @transient protected lazy val rightKeyGenerator: Projection =
- newProjection(rightKeys, right.output)
+ protected[this] def supportUnsafe: Boolean = {
+ (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(rightKeys)
+ && UnsafeProjection.canSupport(left.schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = supportUnsafe
+
+ @transient protected lazy val leftKeyGenerator: Projection =
+ if (supportUnsafe) {
+ UnsafeProjection.create(leftKeys, left.output)
+ } else {
+ newMutableProjection(leftKeys, left.output)()
+ }
- @transient protected lazy val leftKeyGenerator: () => MutableProjection =
- newMutableProjection(leftKeys, left.output)
+ @transient protected lazy val rightKeyGenerator: Projection =
+ if (supportUnsafe) {
+ UnsafeProjection.create(rightKeys, right.output)
+ } else {
+ newMutableProjection(rightKeys, right.output)()
+ }
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- protected def buildKeyHashSet(
- buildIter: Iterator[InternalRow],
- copy: Boolean): java.util.Set[InternalRow] = {
+ protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null
// Create a Hash set of buildKeys
+ val rightKey = rightKeyGenerator
while (buildIter.hasNext) {
currentRow = buildIter.next()
- val rowKey = rightKeyGenerator(currentRow)
+ val rowKey = rightKey(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
- if (copy) {
- hashSet.add(rowKey.copy())
- } else {
- // rowKey may be not serializable (from codegen)
- hashSet.add(rowKey)
- }
+ hashSet.add(rowKey.copy())
}
}
}
@@ -67,25 +78,34 @@ trait HashSemiJoin {
}
protected def hashSemiJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation): Iterator[InternalRow] = {
- val joinKeys = leftKeyGenerator()
- val joinedRow = new JoinedRow
+ streamIter: Iterator[InternalRow],
+ hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
+ val joinKeys = leftKeyGenerator
streamIter.filter(current => {
- lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
- !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
- (build: InternalRow) => boundCondition(joinedRow(current, build))
- }
+ val key = joinKeys(current)
+ !key.anyNull && hashSet.contains(key)
})
}
+ protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+ if (supportUnsafe) {
+ UnsafeHashedRelation(buildIter, rightKeys, right)
+ } else {
+ HashedRelation(buildIter, newProjection(rightKeys, right.output))
+ }
+ }
+
protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
- hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
- val joinKeys = leftKeyGenerator()
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = leftKeyGenerator
val joinedRow = new JoinedRow
- streamIter.filter(current => {
- !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue)
- })
+ streamIter.filter { current =>
+ val key = joinKeys(current)
+ lazy val rowBuffer = hashedRelation.get(key)
+ !key.anyNull && rowBuffer != null && rowBuffer.exists {
+ (row: InternalRow) => boundCondition(joinedRow(current, row))
+ }
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 6b51f5d415..8d5731afd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.execution.joins
-import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.util.{HashMap => JavaHashMap}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.CompactBuffer
@@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
}
}
-
// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
@@ -148,3 +148,80 @@ private[joins] object HashedRelation {
}
}
}
+
+
+/**
+ * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
+ * sequence of values.
+ *
+ * TODO(davies): use BytesToBytesMap
+ */
+private[joins] final class UnsafeHashedRelation(
+ private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
+ extends HashedRelation with Externalizable {
+
+ def this() = this(null) // Needed for serialization
+
+ override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+    // The cast below is safe thanks to type erasure
+ hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ }
+}
+
+private[joins] object UnsafeHashedRelation {
+
+ def apply(
+ input: Iterator[InternalRow],
+ buildKeys: Seq[Expression],
+ buildPlan: SparkPlan,
+ sizeEstimate: Int = 64): HashedRelation = {
+ val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output))
+ apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
+ }
+
+ // Used for tests
+ def apply(
+ input: Iterator[InternalRow],
+ buildKeys: Seq[Expression],
+ rowSchema: StructType,
+ sizeEstimate: Int): HashedRelation = {
+
+ // TODO: Use BytesToBytesMap.
+ val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val toUnsafe = UnsafeProjection.create(rowSchema)
+ val keyGenerator = UnsafeProjection.create(buildKeys)
+
+ // Create a mapping of buildKeys -> rows
+ while (input.hasNext) {
+ val currentRow = input.next()
+ val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
+ currentRow.asInstanceOf[UnsafeRow]
+ } else {
+ toUnsafe(currentRow)
+ }
+ val rowKey = keyGenerator(unsafeRow)
+ if (!rowKey.anyNull) {
+ val existingMatchList = hashTable.get(rowKey)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[UnsafeRow]()
+ hashTable.put(rowKey.copy(), newMatchList)
+ newMatchList
+ } else {
+ existingMatchList
+ }
+ matchList += unsafeRow.copy()
+ }
+ }
+
+ new UnsafeHashedRelation(hashTable)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index db5be9f453..4443455ef1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -39,6 +39,9 @@ case class LeftSemiJoinBNL(
override def output: Seq[Attribute] = left.output
+ override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+
/** The Streamed Relation */
override def left: SparkPlan = streamed
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 9eaac817d9..874712a4e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -43,10 +43,10 @@ case class LeftSemiJoinHash(
protected override def doExecute(): RDD[InternalRow] = {
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(buildIter, copy = false)
+ val hashSet = buildKeyHashSet(buildIter)
hashSemiJoin(streamIter, hashSet)
} else {
- val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ val hashRelation = buildHashRelation(buildIter)
hashSemiJoin(streamIter, hashRelation)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5439e10a60..948d0ccebc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -45,7 +45,7 @@ case class ShuffledHashJoin(
protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
+ val hashed = buildHashRelation(buildIter)
hashJoin(streamIter, hashed)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index ab0a6ad56a..f54f1edd38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin(
// TODO this probably can be replaced by external sort (sort merged join?)
joinType match {
case LeftOuter =>
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
- val keyGenerator = newProjection(leftKeys, left.output)
+ val hashed = buildHashRelation(rightIter)
+ val keyGenerator = streamedKeyGenerator()
leftIter.flatMap( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
+ leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey))
})
case RightOuter =>
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
- val keyGenerator = newProjection(rightKeys, right.output)
+ val hashed = buildHashRelation(leftIter)
+ val keyGenerator = streamedKeyGenerator()
rightIter.flatMap ( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+ rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow)
})
case FullOuter =>
+ // TODO(davies): use UnsafeRow
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
(leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
index 421d510e67..29f3beb3cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
@DeveloperApi
case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
+
+ require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe")
+
override def output: Seq[Attribute] = child.output
override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = false
@@ -93,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
}
case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
- // If this operator's children produce both unsafe and safe rows, then convert everything
- // to unsafe rows
- operator.withNewChildren {
- operator.children.map {
- c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+ // If this operator's children produce both unsafe and safe rows,
+      // convert everything to unsafe rows if all of their schemas are supported by UnsafeRow
+ if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) {
+ operator.withNewChildren {
+ operator.children.map {
+ c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+ }
+ }
+ } else {
+ operator.withNewChildren {
+ operator.children.map {
+ c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
+ }
}
}
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 3854dc1b7a..d36e263937 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.memory.MemoryAllocator
import org.apache.spark.unsafe.types.UTF8String
@@ -31,7 +31,7 @@ class UnsafeRowSuite extends SparkFunSuite {
test("writeToStream") {
val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
val arrayBackedUnsafeRow: UnsafeRow =
- UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)
+ UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row)
assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]])
val bytesFromArrayBackedRow: Array[Byte] = {
val baos = new ByteArrayOutputStream()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 9d9858b1c6..9dd2220f09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
import org.apache.spark.util.collection.CompactBuffer
@@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite {
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])
- assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
- assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
+ assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+ assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(InternalRow(10)) === null)
val data2 = CompactBuffer[InternalRow](data(2))
data2 += data(2)
- assert(hashed.get(data(2)) == data2)
+ assert(hashed.get(data(2)) === data2)
}
test("UniqueKeyHashedRelation") {
@@ -49,15 +51,40 @@ class HashedRelationSuite extends SparkFunSuite {
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
- assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
- assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
- assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2)))
+ assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+ assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
+ assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
assert(hashed.get(InternalRow(10)) === null)
val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
- assert(uniqHashed.getValue(data(0)) == data(0))
- assert(uniqHashed.getValue(data(1)) == data(1))
- assert(uniqHashed.getValue(data(2)) == data(2))
- assert(uniqHashed.getValue(InternalRow(10)) == null)
+ assert(uniqHashed.getValue(data(0)) === data(0))
+ assert(uniqHashed.getValue(data(1)) === data(1))
+ assert(uniqHashed.getValue(data(2)) === data(2))
+ assert(uniqHashed.getValue(InternalRow(10)) === null)
+ }
+
+ test("UnsafeHashedRelation") {
+ val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
+ val buildKey = Seq(BoundReference(0, IntegerType, false))
+ val schema = StructType(StructField("a", IntegerType, true) :: Nil)
+ val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1)
+ assert(hashed.isInstanceOf[UnsafeHashedRelation])
+
+ val toUnsafeKey = UnsafeProjection.create(schema)
+ val unsafeData = data.map(toUnsafeKey(_).copy()).toArray
+ assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
+ assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed.get(toUnsafeKey(InternalRow(10))) === null)
+
+ val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
+ data2 += unsafeData(2).copy()
+ assert(hashed.get(unsafeData(2)) === data2)
+
+ val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
+ .asInstanceOf[UnsafeHashedRelation]
+ assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
+ assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
+ assert(hashed2.get(unsafeData(2)) === data2)
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
index 85cd02469a..61f483ced3 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -44,12 +44,16 @@ public final class Murmur3_x86_32 {
return fmix(h1, 4);
}
- public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
+ public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+ return hashUnsafeWords(base, offset, lengthInBytes, seed);
+ }
+
+ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
// This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
int h1 = seed;
- for (int offset = 0; offset < lengthInBytes; offset += 4) {
- int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+ for (int i = 0; i < lengthInBytes; i += 4) {
+ int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i);
int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
}
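
The static overload takes an explicit seed, which is what the new UnsafeRow.hashCode uses (seed 42 over the row's word-aligned backing bytes). A minimal, self-contained sketch of an equivalent call on a plain byte array (the array contents here are arbitrary):

  import org.apache.spark.unsafe.PlatformDependent
  import org.apache.spark.unsafe.hash.Murmur3_x86_32

  // lengthInBytes must be a multiple of 8, as the assertion above requires.
  val bytes = new Array[Byte](16)
  val hash = Murmur3_x86_32.hashUnsafeWords(
    bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length, 42)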