aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorXiu Guo <xguo27@gmail.com>2016-02-22 16:34:02 -0800
committerReynold Xin <rxin@databricks.com>2016-02-22 16:34:02 -0800
commit2063781840831469b394313694bfd25cbde2bb1e (patch)
tree42395c9d9e7584a25166fff8d9e608c8a762762d /sql
parent173aa949c309ff7a7a03e9d762b9108542219a95 (diff)
downloadspark-2063781840831469b394313694bfd25cbde2bb1e.tar.gz
spark-2063781840831469b394313694bfd25cbde2bb1e.tar.bz2
spark-2063781840831469b394313694bfd25cbde2bb1e.zip
[SPARK-13422][SQL] Use HashedRelation instead of HashSet in Left Semi Joins
Use the HashedRelation which is a more optimized datastructure and reduce code complexity Author: Xiu Guo <xguo27@gmail.com> Closes #11291 from xguo27/SPARK-13422.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala27
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala55
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala13
3 files changed, 14 insertions, 81 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 1f99fbedde..d3bcfad7c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
*/
case class BroadcastLeftSemiJoinHash(
leftKeys: Seq[Expression],
@@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
override def requiredChildDistribution: Seq[Distribution] = {
- val mode = if (condition.isEmpty) {
- HashSetBroadcastMode(rightKeys, right.output)
- } else {
- HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
- }
+ val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- if (condition.isEmpty) {
- val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
- left.execute().mapPartitionsInternal { streamIter =>
- hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
- }
- } else {
- val broadcastedRelation = right.executeBroadcast[HashedRelation]()
- left.execute().mapPartitionsInternal { streamIter =>
- val hashedRelation = broadcastedRelation.value
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
- hashSemiJoin(streamIter, hashedRelation, numOutputRows)
- }
+ val broadcastedRelation = right.executeBroadcast[HashedRelation]()
+ left.execute().mapPartitionsInternal { streamIter =>
+ val hashedRelation = broadcastedRelation.value
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
+ hashSemiJoin(streamIter, hashedRelation, numOutputRows)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 1cb6a00617..3eed6e3e11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -43,24 +43,6 @@ trait HashSemiJoin {
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- protected def buildKeyHashSet(
- buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
- HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
- }
-
- protected def hashSemiJoin(
- streamIter: Iterator[InternalRow],
- hashSet: java.util.Set[InternalRow],
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val joinKeys = leftKeyGenerator
- streamIter.filter(current => {
- val key = joinKeys(current)
- val r = !key.anyNull && hashSet.contains(key)
- if (r) numOutputRows += 1
- r
- })
- }
-
protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation,
@@ -70,44 +52,11 @@ trait HashSemiJoin {
streamIter.filter { current =>
val key = joinKeys(current)
lazy val rowBuffer = hashedRelation.get(key)
- val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
+ val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
(row: InternalRow) => boundCondition(joinedRow(current, row))
- }
+ })
if (r) numOutputRows += 1
r
}
}
}
-
-private[execution] object HashSemiJoin {
- def buildKeyHashSet(
- keys: Seq[Expression],
- attributes: Seq[Attribute],
- rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
- val hashSet = new java.util.HashSet[InternalRow]()
-
- // Create a Hash set of buildKeys
- val key = UnsafeProjection.create(keys, attributes)
- while (rows.hasNext) {
- val currentRow = rows.next()
- val rowKey = key(currentRow)
- if (!rowKey.anyNull) {
- val keyExists = hashSet.contains(rowKey)
- if (!keyExists) {
- hashSet.add(rowKey.copy())
- }
- }
- }
- hashSet
- }
-}
-
-/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
-private[execution] case class HashSetBroadcastMode(
- keys: Seq[Expression],
- attributes: Seq[Attribute]) extends BroadcastMode {
-
- override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
- HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index d8d3045ccf..242ed61232 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
*/
case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
@@ -47,13 +47,8 @@ case class LeftSemiJoinHash(
val numOutputRows = longMetric("numOutputRows")
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
- if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(buildIter)
- hashSemiJoin(streamIter, hashSet, numOutputRows)
- } else {
- val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
- hashSemiJoin(streamIter, hashRelation, numOutputRows)
- }
+ val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ hashSemiJoin(streamIter, hashRelation, numOutputRows)
}
}
}