path: root/sql/core/src/main
author     Reynold Xin <rxin@apache.org>              2014-10-08 18:17:01 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-10-08 18:17:01 -0700
commit     bcb1ae049b447c37418747e0a262f54f9fc1664a (patch)
tree       c29b2f8dcfe52cfd93d1ccb77f9fea3b352b8dee /sql/core/src/main
parent     3e4f09d2fce9dcf45eaaca827f2cf15c9d4a6c75 (diff)
download   spark-bcb1ae049b447c37418747e0a262f54f9fc1664a.tar.gz
           spark-bcb1ae049b447c37418747e0a262f54f9fc1664a.tar.bz2
           spark-bcb1ae049b447c37418747e0a262f54f9fc1664a.zip
[SPARK-3857] Create joins package for various join operators.
Author: Reynold Xin <rxin@apache.org>

Closes #2719 from rxin/sql-join-break and squashes the following commits:

0c0082b [Reynold Xin] Fix line length.
cbc664c [Reynold Xin] Rename join -> joins package.
a070d44 [Reynold Xin] Fix line length in HashJoin
a39be8c [Reynold Xin] [SPARK-3857] Create a join package for various join operators.
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 41
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala                          624
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala         62
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala  144
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala          40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala                 123
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala            222
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala           73
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala          67
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala          49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala                   37
11 files changed, 839 insertions, 643 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 883f2ff521..bbf17b9fad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.parquet._
+
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
@@ -34,13 +35,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
- val semiJoin = execution.LeftSemiJoinHash(
+ val semiJoin = joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
- execution.LeftSemiJoinBNL(
- planLater(left), planLater(right), condition) :: Nil
+ joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
@@ -50,13 +50,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* evaluated by matching hash keys.
*
* This strategy applies a simple optimization based on the estimates of the physical sizes of
- * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an
+ * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an
* estimated physical size smaller than the user-settable threshold
* [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the
* ''build'' relation and mark the other relation as the ''stream'' side. The build table will be
* ''broadcasted'' to all of the executors involved in the join, as a
* [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
- * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]].
+ * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]].
*/
object HashJoin extends Strategy with PredicateHelper {
@@ -66,8 +66,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
left: LogicalPlan,
right: LogicalPlan,
condition: Option[Expression],
- side: BuildSide) = {
- val broadcastHashJoin = execution.BroadcastHashJoin(
+ side: joins.BuildSide) = {
+ val broadcastHashJoin = execution.joins.BroadcastHashJoin(
leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}
@@ -76,27 +76,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
+ makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.autoBroadcastJoinThreshold > 0 &&
left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
+ makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
- BuildRight
+ joins.BuildRight
} else {
- BuildLeft
+ joins.BuildLeft
}
- val hashJoin =
- execution.ShuffledHashJoin(
- leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
+ val hashJoin = joins.ShuffledHashJoin(
+ leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
- execution.HashOuterJoin(
+ joins.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
case _ => Nil
@@ -164,8 +163,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
val buildSide =
- if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft
- execution.BroadcastNestedLoopJoin(
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+ joins.BuildRight
+ } else {
+ joins.BuildLeft
+ }
+ joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case _ => Nil
}
@@ -174,10 +177,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object CartesianProduct extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, _, None) =>
- execution.CartesianProduct(planLater(left), planLater(right)) :: Nil
+ execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
case logical.Join(left, right, Inner, Some(condition)) =>
execution.Filter(condition,
- execution.CartesianProduct(planLater(left), planLater(right))) :: Nil
+ execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil
case _ => Nil
}
}
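
The build-side rule described in the HashJoin strategy's Scaladoc above reduces to comparing each
side's estimated physical size against the broadcast threshold. A minimal, self-contained Scala
sketch of that decision (SideStats and the hard-coded sizes are illustrative stand-ins, not the
planner's actual statistics API):

    // Sketch of the build-side choice: broadcast a side that fits under the threshold,
    // otherwise shuffle and build on whichever side is estimated to be smaller.
    object BuildSideChoiceSketch {
      sealed trait BuildSide
      case object BuildLeft extends BuildSide
      case object BuildRight extends BuildSide

      final case class SideStats(sizeInBytes: Long)   // stand-in for plan statistics

      def chooseBuildSide(
          left: SideStats,
          right: SideStats,
          autoBroadcastJoinThreshold: Long): (BuildSide, Boolean) = {
        if (autoBroadcastJoinThreshold > 0 && right.sizeInBytes <= autoBroadcastJoinThreshold) {
          (BuildRight, true)                           // broadcast the right side
        } else if (autoBroadcastJoinThreshold > 0 && left.sizeInBytes <= autoBroadcastJoinThreshold) {
          (BuildLeft, true)                            // broadcast the left side
        } else if (right.sizeInBytes <= left.sizeInBytes) {
          (BuildRight, false)                          // shuffled hash join, build right
        } else {
          (BuildLeft, false)                           // shuffled hash join, build left
        }
      }

      def main(args: Array[String]): Unit = {
        val choice = chooseBuildSide(
          left = SideStats(512L * 1024 * 1024),
          right = SideStats(64L * 1024),
          autoBroadcastJoinThreshold = 10L * 1024 * 1024)
        println(choice)                                // (BuildRight,true)
      }
    }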
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
deleted file mode 100644
index 2890a563be..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ /dev/null
@@ -1,624 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.util.{HashMap => JavaHashMap}
-
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent._
-import scala.concurrent.duration._
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.util.collection.CompactBuffer
-
-@DeveloperApi
-sealed abstract class BuildSide
-
-@DeveloperApi
-case object BuildLeft extends BuildSide
-
-@DeveloperApi
-case object BuildRight extends BuildSide
-
-trait HashJoin {
- self: SparkPlan =>
-
- val leftKeys: Seq[Expression]
- val rightKeys: Seq[Expression]
- val buildSide: BuildSide
- val left: SparkPlan
- val right: SparkPlan
-
- lazy val (buildPlan, streamedPlan) = buildSide match {
- case BuildLeft => (left, right)
- case BuildRight => (right, left)
- }
-
- lazy val (buildKeys, streamedKeys) = buildSide match {
- case BuildLeft => (leftKeys, rightKeys)
- case BuildRight => (rightKeys, leftKeys)
- }
-
- def output = left.output ++ right.output
-
- @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
- @transient lazy val streamSideKeyGenerator =
- newMutableProjection(streamedKeys, streamedPlan.output)
-
- def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
- // TODO: Use Spark's HashMap implementation.
-
- val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
- var currentRow: Row = null
-
- // Create a mapping of buildKeys -> rows
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if (!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[Row]()
- hashTable.put(rowKey, newMatchList)
- newMatchList
- } else {
- existingMatchList
- }
- matchList += currentRow.copy()
- }
- }
-
- new Iterator[Row] {
- private[this] var currentStreamedRow: Row = _
- private[this] var currentHashMatches: CompactBuffer[Row] = _
- private[this] var currentMatchPosition: Int = -1
-
- // Mutable per row objects.
- private[this] val joinRow = new JoinedRow2
-
- private[this] val joinKeys = streamSideKeyGenerator()
-
- override final def hasNext: Boolean =
- (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
- (streamIter.hasNext && fetchNext())
-
- override final def next() = {
- val ret = buildSide match {
- case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
- }
- currentMatchPosition += 1
- ret
- }
-
- /**
- * Searches the streamed iterator for the next row that has at least one match in hashtable.
- *
- * @return true if the search is successful, and false if the streamed iterator runs out of
- * tuples.
- */
- private final def fetchNext(): Boolean = {
- currentHashMatches = null
- currentMatchPosition = -1
-
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatches = hashTable.get(joinKeys.currentValue)
- }
- }
-
- if (currentHashMatches == null) {
- false
- } else {
- currentMatchPosition = 0
- true
- }
- }
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Performs a hash based outer join for two child relations by shuffling the data using
- * the join keys. This operator requires loading the associated partition in both side into memory.
- */
-@DeveloperApi
-case class HashOuterJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- joinType: JoinType,
- condition: Option[Expression],
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode {
-
- override def outputPartitioning: Partitioning = joinType match {
- case LeftOuter => left.outputPartitioning
- case RightOuter => right.outputPartitioning
- case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
- case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
- }
-
- override def requiredChildDistribution =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- override def output = {
- joinType match {
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case x =>
- throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
- }
- }
-
- @transient private[this] lazy val DUMMY_LIST = Seq[Row](null)
- @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
-
- // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
- // iterator for performance purpose.
-
- private[this] def leftOuterIterator(
- key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
- val joinedRow = new JoinedRow()
- val rightNullRow = new GenericRow(right.output.length)
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
-
- leftIter.iterator.flatMap { l =>
- joinedRow.withLeft(l)
- var matched = false
- (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
- matched = true
- joinedRow.copy
- } else {
- Nil
- }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
- // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all of the
- // records in right side.
- // If we didn't get any proper row, then append a single row with empty right
- joinedRow.withRight(rightNullRow).copy
- })
- }
- }
-
- private[this] def rightOuterIterator(
- key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
- val joinedRow = new JoinedRow()
- val leftNullRow = new GenericRow(left.output.length)
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
-
- rightIter.iterator.flatMap { r =>
- joinedRow.withRight(r)
- var matched = false
- (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
- matched = true
- joinedRow.copy
- } else {
- Nil
- }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
- // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all of the
- // records in left side.
- // If we didn't get any proper row, then append a single row with empty left.
- joinedRow.withLeft(leftNullRow).copy
- })
- }
- }
-
- private[this] def fullOuterIterator(
- key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
- val joinedRow = new JoinedRow()
- val leftNullRow = new GenericRow(left.output.length)
- val rightNullRow = new GenericRow(right.output.length)
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
-
- if (!key.anyNull) {
- // Store the positions of records in right, if one of its associated row satisfy
- // the join condition.
- val rightMatchedSet = scala.collection.mutable.Set[Int]()
- leftIter.iterator.flatMap[Row] { l =>
- joinedRow.withLeft(l)
- var matched = false
- rightIter.zipWithIndex.collect {
- // 1. For those matched (satisfy the join condition) records with both sides filled,
- // append them directly
-
- case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
- matched = true
- // if the row satisfy the join condition, add its index into the matched set
- rightMatchedSet.add(idx)
- joinedRow.copy
- }
- } ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
- // 2. For those unmatched records in left, append additional records with empty right.
-
- // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all
- // of the records in right side.
- // If we didn't get any proper row, then append a single row with empty right.
- joinedRow.withRight(rightNullRow).copy
- })
- } ++ rightIter.zipWithIndex.collect {
- // 3. For those unmatched records in right, append additional records with empty left.
-
- // Re-visiting the records in right, and append additional row with empty left, if its not
- // in the matched set.
- case (r, idx) if (!rightMatchedSet.contains(idx)) => {
- joinedRow(leftNullRow, r).copy
- }
- }
- } else {
- leftIter.iterator.map[Row] { l =>
- joinedRow(l, rightNullRow).copy
- } ++ rightIter.iterator.map[Row] { r =>
- joinedRow(leftNullRow, r).copy
- }
- }
- }
-
- private[this] def buildHashTable(
- iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
- val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
- while (iter.hasNext) {
- val currentRow = iter.next()
- val rowKey = keyGenerator(currentRow)
-
- var existingMatchList = hashTable.get(rowKey)
- if (existingMatchList == null) {
- existingMatchList = new CompactBuffer[Row]()
- hashTable.put(rowKey, existingMatchList)
- }
-
- existingMatchList += currentRow.copy()
- }
-
- hashTable
- }
-
- def execute() = {
- left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
- // TODO this probably can be replaced by external sort (sort merged join?)
- // Build HashMap for current partition in left relation
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
- // Build HashMap for current partition in right relation
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
-
- import scala.collection.JavaConversions._
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
- joinType match {
- case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
- leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
- rightHashTable.getOrElse(key, EMPTY_LIST))
- }
- case RightOuter => rightHashTable.keysIterator.flatMap { key =>
- rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
- rightHashTable.getOrElse(key, EMPTY_LIST))
- }
- case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
- fullOuterIterator(key,
- leftHashTable.getOrElse(key, EMPTY_LIST),
- rightHashTable.getOrElse(key, EMPTY_LIST))
- }
- case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
- }
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Performs an inner hash join of two child relations by first shuffling the data using the join
- * keys.
- */
-@DeveloperApi
-case class ShuffledHashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- buildSide: BuildSide,
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- override def outputPartitioning: Partitioning = left.outputPartitioning
-
- override def requiredChildDistribution =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- def execute() = {
- buildPlan.execute().zipPartitions(streamedPlan.execute()) {
- (buildIter, streamIter) => joinIterators(buildIter, streamIter)
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
- */
-@DeveloperApi
-case class LeftSemiJoinHash(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- val buildSide = BuildRight
-
- override def requiredChildDistribution =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- override def output = left.output
-
- def execute() = {
- buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashSet = new java.util.HashSet[Row]()
- var currentRow: Row = null
-
- // Create a Hash set of buildKeys
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if (!rowKey.anyNull) {
- val keyExists = hashSet.contains(rowKey)
- if (!keyExists) {
- hashSet.add(rowKey)
- }
- }
- }
-
- val joinKeys = streamSideKeyGenerator()
- streamIter.filter(current => {
- !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
- })
- }
- }
-}
-
-
-/**
- * :: DeveloperApi ::
- * Performs an inner hash join of two child relations. When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
- * relation is not shuffled.
- */
-@DeveloperApi
-case class BroadcastHashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- buildSide: BuildSide,
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
- override def requiredChildDistribution =
- UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
- @transient
- val broadcastFuture = future {
- sparkContext.broadcast(buildPlan.executeCollect())
- }
-
- def execute() = {
- val broadcastRelation = Await.result(broadcastFuture, 5.minute)
-
- streamedPlan.execute().mapPartitions { streamedIter =>
- joinIterators(broadcastRelation.value.iterator, streamedIter)
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys
- * for hash join.
- */
-@DeveloperApi
-case class LeftSemiJoinBNL(
- streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- extends BinaryNode {
- // TODO: Override requiredChildDistribution.
-
- override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
- def output = left.output
-
- /** The Streamed Relation */
- def left = streamed
- /** The Broadcast relation */
- def right = broadcast
-
- @transient lazy val boundCondition =
- InterpretedPredicate(
- condition
- .map(c => BindReferences.bindReference(c, left.output ++ right.output))
- .getOrElse(Literal(true)))
-
- def execute() = {
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
-
- streamed.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow
-
- streamedIter.filter(streamedRow => {
- var i = 0
- var matched = false
-
- while (i < broadcastedRelation.value.size && !matched) {
- val broadcastedRow = broadcastedRelation.value(i)
- if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
- matched = true
- }
- i += 1
- }
- matched
- })
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
- def output = left.output ++ right.output
-
- def execute() = {
- val leftResults = left.execute().map(_.copy())
- val rightResults = right.execute().map(_.copy())
-
- leftResults.cartesian(rightResults).mapPartitions { iter =>
- val joinedRow = new JoinedRow
- iter.map(r => joinedRow(r._1, r._2))
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class BroadcastNestedLoopJoin(
- left: SparkPlan,
- right: SparkPlan,
- buildSide: BuildSide,
- joinType: JoinType,
- condition: Option[Expression]) extends BinaryNode {
- // TODO: Override requiredChildDistribution.
-
- /** BuildRight means the right relation <=> the broadcast relation. */
- val (streamed, broadcast) = buildSide match {
- case BuildRight => (left, right)
- case BuildLeft => (right, left)
- }
-
- override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
- override def output = {
- joinType match {
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case _ =>
- left.output ++ right.output
- }
- }
-
- @transient lazy val boundCondition =
- InterpretedPredicate(
- condition
- .map(c => BindReferences.bindReference(c, left.output ++ right.output))
- .getOrElse(Literal(true)))
-
- def execute() = {
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
-
- /** All rows that either match both-way, or rows from streamed joined with nulls. */
- val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
- val matchedRows = new CompactBuffer[Row]
- // TODO: Use Spark's BitSet.
- val includedBroadcastTuples =
- new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
- val joinedRow = new JoinedRow
- val leftNulls = new GenericMutableRow(left.output.size)
- val rightNulls = new GenericMutableRow(right.output.size)
-
- streamedIter.foreach { streamedRow =>
- var i = 0
- var streamRowMatched = false
-
- while (i < broadcastedRelation.value.size) {
- // TODO: One bitset per partition instead of per row.
- val broadcastedRow = broadcastedRelation.value(i)
- buildSide match {
- case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
- matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
- streamRowMatched = true
- includedBroadcastTuples += i
- case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
- matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
- streamRowMatched = true
- includedBroadcastTuples += i
- case _ =>
- }
- i += 1
- }
-
- (streamRowMatched, joinType, buildSide) match {
- case (false, LeftOuter | FullOuter, BuildRight) =>
- matchedRows += joinedRow(streamedRow, rightNulls).copy()
- case (false, RightOuter | FullOuter, BuildLeft) =>
- matchedRows += joinedRow(leftNulls, streamedRow).copy()
- case _ =>
- }
- }
- Iterator((matchedRows, includedBroadcastTuples))
- }
-
- val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
- val allIncludedBroadcastTuples =
- if (includedBroadcastTuples.count == 0) {
- new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
- } else {
- includedBroadcastTuples.reduce(_ ++ _)
- }
-
- val leftNulls = new GenericMutableRow(left.output.size)
- val rightNulls = new GenericMutableRow(right.output.size)
- /** Rows from broadcasted joined with nulls. */
- val broadcastRowsWithNulls: Seq[Row] = {
- val buf: CompactBuffer[Row] = new CompactBuffer()
- var i = 0
- val rel = broadcastedRelation.value
- while (i < rel.length) {
- if (!allIncludedBroadcastTuples.contains(i)) {
- (joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
- case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
- case _ =>
- }
- }
- i += 1
- }
- buf.toSeq
- }
-
- // TODO: Breaks lineage.
- sparkContext.union(
- matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
new file mode 100644
index 0000000000..d88ab6367a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.concurrent._
+import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Performs an inner hash join of two child relations. When the output RDD of this operator is
+ * being constructed, a Spark job is asynchronously started to calculate the values for the
+ * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
+ * relation is not shuffled.
+ */
+@DeveloperApi
+case class BroadcastHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ buildSide: BuildSide,
+ left: SparkPlan,
+ right: SparkPlan)
+ extends BinaryNode with HashJoin {
+
+ override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
+ override def requiredChildDistribution =
+ UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+
+ @transient
+ private val broadcastFuture = future {
+ sparkContext.broadcast(buildPlan.executeCollect())
+ }
+
+ override def execute() = {
+ val broadcastRelation = Await.result(broadcastFuture, 5.minute)
+
+ streamedPlan.execute().mapPartitions { streamedIter =>
+ joinIterators(broadcastRelation.value.iterator, streamedIter)
+ }
+ }
+}
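
The Scaladoc above describes starting the broadcast of the build side as soon as the operator is
constructed and only blocking on it in execute(). A toy, Spark-free sketch of that eager-future
pattern using scala.concurrent (EagerBuildSide and its semi-join body are assumptions for
illustration, not the operator itself):

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    // Start computing the (expensive) build side at construction time; block only
    // when execution actually needs the result, mirroring broadcastFuture above.
    class EagerBuildSide(buildExpensiveSide: () => Seq[Int]) {
      private val buildFuture: Future[Set[Int]] = Future { buildExpensiveSide().toSet }

      def execute(stream: Iterator[Int]): Iterator[Int] = {
        val keys = Await.result(buildFuture, 5.minutes)   // analogous to Await in execute()
        stream.filter(keys.contains)
      }
    }

    object EagerBuildSideDemo extends App {
      val op = new EagerBuildSide(() => Seq(1, 2, 3))
      println(op.execute(Iterator(2, 4, 6)).toList)       // List(2)
    }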
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
new file mode 100644
index 0000000000..36aad13778
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class BroadcastNestedLoopJoin(
+ left: SparkPlan,
+ right: SparkPlan,
+ buildSide: BuildSide,
+ joinType: JoinType,
+ condition: Option[Expression]) extends BinaryNode {
+ // TODO: Override requiredChildDistribution.
+
+ /** BuildRight means the right relation <=> the broadcast relation. */
+ private val (streamed, broadcast) = buildSide match {
+ case BuildRight => (left, right)
+ case BuildLeft => (right, left)
+ }
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def output = {
+ joinType match {
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case _ =>
+ left.output ++ right.output
+ }
+ }
+
+ @transient private lazy val boundCondition =
+ InterpretedPredicate(
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true)))
+
+ override def execute() = {
+ val broadcastedRelation =
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+ /** All rows that either match both-way, or rows from streamed joined with nulls. */
+ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
+ val matchedRows = new CompactBuffer[Row]
+ // TODO: Use Spark's BitSet.
+ val includedBroadcastTuples =
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ val joinedRow = new JoinedRow
+ val leftNulls = new GenericMutableRow(left.output.size)
+ val rightNulls = new GenericMutableRow(right.output.size)
+
+ streamedIter.foreach { streamedRow =>
+ var i = 0
+ var streamRowMatched = false
+
+ while (i < broadcastedRelation.value.size) {
+ // TODO: One bitset per partition instead of per row.
+ val broadcastedRow = broadcastedRelation.value(i)
+ buildSide match {
+ case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
+ matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+ streamRowMatched = true
+ includedBroadcastTuples += i
+ case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
+ matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+ streamRowMatched = true
+ includedBroadcastTuples += i
+ case _ =>
+ }
+ i += 1
+ }
+
+ (streamRowMatched, joinType, buildSide) match {
+ case (false, LeftOuter | FullOuter, BuildRight) =>
+ matchedRows += joinedRow(streamedRow, rightNulls).copy()
+ case (false, RightOuter | FullOuter, BuildLeft) =>
+ matchedRows += joinedRow(leftNulls, streamedRow).copy()
+ case _ =>
+ }
+ }
+ Iterator((matchedRows, includedBroadcastTuples))
+ }
+
+ val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
+ val allIncludedBroadcastTuples =
+ if (includedBroadcastTuples.count == 0) {
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ } else {
+ includedBroadcastTuples.reduce(_ ++ _)
+ }
+
+ val leftNulls = new GenericMutableRow(left.output.size)
+ val rightNulls = new GenericMutableRow(right.output.size)
+ /** Rows from broadcasted joined with nulls. */
+ val broadcastRowsWithNulls: Seq[Row] = {
+ val buf: CompactBuffer[Row] = new CompactBuffer()
+ var i = 0
+ val rel = broadcastedRelation.value
+ while (i < rel.length) {
+ if (!allIncludedBroadcastTuples.contains(i)) {
+ (joinType, buildSide) match {
+ case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
+ case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
+ case _ =>
+ }
+ }
+ i += 1
+ }
+ buf.toSeq
+ }
+
+ // TODO: Breaks lineage.
+ sparkContext.union(
+ matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
+ }
+}
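
BroadcastNestedLoopJoin above matches every streamed row against the entire broadcast relation and
records which broadcast rows matched, so the unmatched ones can be null-padded afterwards for outer
joins. A compact standalone sketch of that two-pass idea (Option stands in for null padding; the
helper and its names are illustrative, not the operator's API):

    import scala.collection.mutable

    object NestedLoopOuterJoinSketch {
      // Pass 1: join streamed rows against all broadcast rows, remembering matched indices.
      // Pass 2: emit the broadcast rows that never matched, padded with None on the streamed side.
      def fullOuterNestedLoop[A, B](
          streamed: Seq[A],
          broadcast: IndexedSeq[B],
          matches: (A, B) => Boolean): Seq[(Option[A], Option[B])] = {
        val matchedBroadcast = mutable.BitSet.empty

        val streamedSide: Seq[(Option[A], Option[B])] = streamed.flatMap { s =>
          val joined: Seq[(Option[A], Option[B])] = broadcast.zipWithIndex.collect {
            case (b, i) if matches(s, b) =>
              matchedBroadcast += i
              (Some(s), Some(b))
          }
          if (joined.nonEmpty) joined else Seq((Some(s), None))   // null-pad the broadcast side
        }

        val broadcastOnly: Seq[(Option[A], Option[B])] = broadcast.zipWithIndex.collect {
          case (b, i) if !matchedBroadcast.contains(i) => (None, Some(b))  // null-pad the streamed side
        }

        streamedSide ++ broadcastOnly
      }

      def main(args: Array[String]): Unit = {
        fullOuterNestedLoop(Seq(1, 2), IndexedSeq(2, 3), (a: Int, b: Int) => a == b).foreach(println)
        // (Some(1),None), (Some(2),Some(2)), (None,Some(3))
      }
    }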
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
new file mode 100644
index 0000000000..76c14c02aa
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
+ override def output = left.output ++ right.output
+
+ override def execute() = {
+ val leftResults = left.execute().map(_.copy())
+ val rightResults = right.execute().map(_.copy())
+
+ leftResults.cartesian(rightResults).mapPartitions { iter =>
+ val joinedRow = new JoinedRow
+ iter.map(r => joinedRow(r._1, r._2))
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
new file mode 100644
index 0000000000..472b2e6ca6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.util.collection.CompactBuffer
+
+
+trait HashJoin {
+ self: SparkPlan =>
+
+ val leftKeys: Seq[Expression]
+ val rightKeys: Seq[Expression]
+ val buildSide: BuildSide
+ val left: SparkPlan
+ val right: SparkPlan
+
+ protected lazy val (buildPlan, streamedPlan) = buildSide match {
+ case BuildLeft => (left, right)
+ case BuildRight => (right, left)
+ }
+
+ protected lazy val (buildKeys, streamedKeys) = buildSide match {
+ case BuildLeft => (leftKeys, rightKeys)
+ case BuildRight => (rightKeys, leftKeys)
+ }
+
+ override def output = left.output ++ right.output
+
+ @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
+ @transient protected lazy val streamSideKeyGenerator =
+ newMutableProjection(streamedKeys, streamedPlan.output)
+
+ protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] =
+ {
+ // TODO: Use Spark's HashMap implementation.
+
+ val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
+ var currentRow: Row = null
+
+ // Create a mapping of buildKeys -> rows
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = buildSideKeyGenerator(currentRow)
+ if (!rowKey.anyNull) {
+ val existingMatchList = hashTable.get(rowKey)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[Row]()
+ hashTable.put(rowKey, newMatchList)
+ newMatchList
+ } else {
+ existingMatchList
+ }
+ matchList += currentRow.copy()
+ }
+ }
+
+ new Iterator[Row] {
+ private[this] var currentStreamedRow: Row = _
+ private[this] var currentHashMatches: CompactBuffer[Row] = _
+ private[this] var currentMatchPosition: Int = -1
+
+ // Mutable per row objects.
+ private[this] val joinRow = new JoinedRow2
+
+ private[this] val joinKeys = streamSideKeyGenerator()
+
+ override final def hasNext: Boolean =
+ (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+ (streamIter.hasNext && fetchNext())
+
+ override final def next() = {
+ val ret = buildSide match {
+ case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+ case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+ }
+ currentMatchPosition += 1
+ ret
+ }
+
+ /**
+ * Searches the streamed iterator for the next row that has at least one match in the hash table.
+ *
+ * @return true if the search is successful, and false if the streamed iterator runs out of
+ * tuples.
+ */
+ private final def fetchNext(): Boolean = {
+ currentHashMatches = null
+ currentMatchPosition = -1
+
+ while (currentHashMatches == null && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ if (!joinKeys(currentStreamedRow).anyNull) {
+ currentHashMatches = hashTable.get(joinKeys.currentValue)
+ }
+ }
+
+ if (currentHashMatches == null) {
+ false
+ } else {
+ currentMatchPosition = 0
+ true
+ }
+ }
+ }
+ }
+}
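
The joinIterators method in the HashJoin trait above is a classic build/probe hash join: hash the
build side by join key, then stream the other side and emit one output row per matching build row.
A minimal, Spark-free sketch of the same idea over plain (key, value) pairs (the object and its
main method are illustrative only):

    object HashJoinSketch {
      // Null-keyed rows are skipped on both sides, mirroring the anyNull checks above.
      def hashJoin[K, A, B](build: Iterator[(K, A)], stream: Iterator[(K, B)]): Iterator[(K, A, B)] = {
        // Build phase: key -> all matching build rows.
        val table = scala.collection.mutable.Map.empty[K, List[A]]
        for ((k, a) <- build if k != null) {
          table.update(k, a :: table.getOrElse(k, Nil))
        }
        // Probe phase: for each streamed row, emit one output row per match.
        for {
          (k, b) <- stream
          if k != null
          a <- table.getOrElse(k, Nil).iterator
        } yield (k, a, b)
      }

      def main(args: Array[String]): Unit = {
        val build = Iterator(("a", 1), ("b", 2), ("a", 3))
        val stream = Iterator(("a", "x"), ("c", "y"))
        println(hashJoin(build, stream).toList)   // List((a,3,x), (a,1,x))
      }
    }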
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
new file mode 100644
index 0000000000..b73041d306
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.{HashMap => JavaHashMap}
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs a hash-based outer join of two child relations by shuffling the data using
+ * the join keys. This operator requires loading the associated partitions of both sides into memory.
+ */
+@DeveloperApi
+case class HashOuterJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def outputPartitioning: Partitioning = joinType match {
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ }
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def output = {
+ joinType match {
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case x =>
+ throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ }
+ }
+
+ @transient private[this] lazy val DUMMY_LIST = Seq[Row](null)
+ @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
+
+ // TODO: we need to rewrite all of the iterators with our own implementation instead of the Scala
+ // iterators, for performance reasons.
+
+ private[this] def leftOuterIterator(
+ key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
+ val joinedRow = new JoinedRow()
+ val rightNullRow = new GenericRow(right.output.length)
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+
+ leftIter.iterator.flatMap { l =>
+ joinedRow.withLeft(l)
+ var matched = false
+ (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
+ matched = true
+ joinedRow.copy
+ } else {
+ Nil
+ }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
+ // DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row,
+ // as we don't know whether we need to append it until we finish iterating over all of
+ // the records in the right side.
+ // If we didn't get any matching row, append a single row with an empty right side.
+ joinedRow.withRight(rightNullRow).copy
+ })
+ }
+ }
+
+ private[this] def rightOuterIterator(
+ key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
+ val joinedRow = new JoinedRow()
+ val leftNullRow = new GenericRow(left.output.length)
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+
+ rightIter.iterator.flatMap { r =>
+ joinedRow.withRight(r)
+ var matched = false
+ (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
+ matched = true
+ joinedRow.copy
+ } else {
+ Nil
+ }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
+ // DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row,
+ // as we don't know whether we need to append it until we finish iterating over all of
+ // the records in the left side.
+ // If we didn't get any matching row, append a single row with an empty left side.
+ joinedRow.withLeft(leftNullRow).copy
+ })
+ }
+ }
+
+ private[this] def fullOuterIterator(
+ key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
+ val joinedRow = new JoinedRow()
+ val leftNullRow = new GenericRow(left.output.length)
+ val rightNullRow = new GenericRow(right.output.length)
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+
+ if (!key.anyNull) {
+ // Store the positions of records in right whose associated rows satisfy
+ // the join condition.
+ val rightMatchedSet = scala.collection.mutable.Set[Int]()
+ leftIter.iterator.flatMap[Row] { l =>
+ joinedRow.withLeft(l)
+ var matched = false
+ rightIter.zipWithIndex.collect {
+ // 1. For those matched (satisfy the join condition) records with both sides filled,
+ // append them directly
+
+ case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
+ matched = true
+ // if the row satisfies the join condition, add its index to the matched set
+ rightMatchedSet.add(idx)
+ joinedRow.copy
+ }
+ } ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
+ // 2. For those unmatched records in left, append additional records with empty right.
+
+ // DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row,
+ // as we don't know whether we need to append it until we finish iterating over all
+ // of the records in the right side.
+ // If we didn't get any matching row, append a single row with an empty right side.
+ joinedRow.withRight(rightNullRow).copy
+ })
+ } ++ rightIter.zipWithIndex.collect {
+ // 3. For those unmatched records in right, append additional records with empty left.
+
+ // Re-visit the records in right and append an additional row with an empty left side
+ // if it's not in the matched set.
+ case (r, idx) if (!rightMatchedSet.contains(idx)) => {
+ joinedRow(leftNullRow, r).copy
+ }
+ }
+ } else {
+ leftIter.iterator.map[Row] { l =>
+ joinedRow(l, rightNullRow).copy
+ } ++ rightIter.iterator.map[Row] { r =>
+ joinedRow(leftNullRow, r).copy
+ }
+ }
+ }
+
+ private[this] def buildHashTable(
+ iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
+ val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
+ while (iter.hasNext) {
+ val currentRow = iter.next()
+ val rowKey = keyGenerator(currentRow)
+
+ var existingMatchList = hashTable.get(rowKey)
+ if (existingMatchList == null) {
+ existingMatchList = new CompactBuffer[Row]()
+ hashTable.put(rowKey, existingMatchList)
+ }
+
+ existingMatchList += currentRow.copy()
+ }
+
+ hashTable
+ }
+
+ override def execute() = {
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ // TODO this probably can be replaced by external sort (sort merged join?)
+ // Build HashMap for current partition in left relation
+ val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
+ // Build HashMap for current partition in right relation
+ val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+ joinType match {
+ case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
+ leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
+ rightHashTable.getOrElse(key, EMPTY_LIST))
+ }
+ case RightOuter => rightHashTable.keysIterator.flatMap { key =>
+ rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
+ rightHashTable.getOrElse(key, EMPTY_LIST))
+ }
+ case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
+ fullOuterIterator(key,
+ leftHashTable.getOrElse(key, EMPTY_LIST),
+ rightHashTable.getOrElse(key, EMPTY_LIST))
+ }
+ case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ }
+ }
+ }
+}
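
HashOuterJoin pads the non-matching side with nulls; the DUMMY_LIST trick above exists because the
operator cannot know whether a row matched until it has finished scanning the candidates from the
other side. A small, Spark-free sketch of the same null-padding idea for a full outer join over
grouped rows (Option stands in for nulls; the names are illustrative):

    object FullOuterJoinSketch {
      def fullOuterJoin[K, A, B](
          left: Map[K, Seq[A]],
          right: Map[K, Seq[B]]): Seq[(K, Option[A], Option[B])] = {
        (left.keySet ++ right.keySet).toSeq.flatMap { key =>
          val ls = left.getOrElse(key, Seq.empty)
          val rs = right.getOrElse(key, Seq.empty)
          if (ls.nonEmpty && rs.nonEmpty) {
            for (l <- ls; r <- rs) yield (key, Some(l), Some(r))   // matched: both sides filled
          } else if (ls.nonEmpty) {
            ls.map(l => (key, Some(l), None))                      // unmatched left: pad the right side
          } else {
            rs.map(r => (key, None, Some(r)))                      // unmatched right: pad the left side
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val left = Map(1 -> Seq("l1"), 2 -> Seq("l2"))
        val right = Map(2 -> Seq("r2"), 3 -> Seq("r3"))
        fullOuterJoin(left, right).foreach(println)
        // prints, in some order: (1,Some(l1),None), (2,Some(l2),Some(r2)), (3,None,Some(r3))
      }
    }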
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
new file mode 100644
index 0000000000..60003d1900
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no join
+ * keys suitable for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+ streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+ extends BinaryNode {
+ // TODO: Override requiredChildDistribution.
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def output = left.output
+
+ /** The Streamed Relation */
+ override def left = streamed
+ /** The Broadcast relation */
+ override def right = broadcast
+
+ @transient private lazy val boundCondition =
+ InterpretedPredicate(
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true)))
+
+ override def execute() = {
+ val broadcastedRelation =
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+ streamed.execute().mapPartitions { streamedIter =>
+ val joinedRow = new JoinedRow
+
+ streamedIter.filter(streamedRow => {
+ var i = 0
+ var matched = false
+
+ while (i < broadcastedRelation.value.size && !matched) {
+ val broadcastedRow = broadcastedRelation.value(i)
+ if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+ matched = true
+ }
+ i += 1
+ }
+ matched
+ })
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
new file mode 100644
index 0000000000..ea7babf3be
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.{Expression, Row}
+import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Builds the right table's join keys into a HashSet, then iterates through the left
+ * table to check whether each row's join keys are in that hash set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode with HashJoin {
+
+ override val buildSide = BuildRight
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def output = left.output
+
+ override def execute() = {
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+ val hashSet = new java.util.HashSet[Row]()
+ var currentRow: Row = null
+
+ // Create a Hash set of buildKeys
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = buildSideKeyGenerator(currentRow)
+ if (!rowKey.anyNull) {
+ val keyExists = hashSet.contains(rowKey)
+ if (!keyExists) {
+ hashSet.add(rowKey)
+ }
+ }
+ }
+
+ val joinKeys = streamSideKeyGenerator()
+ streamIter.filter(current => {
+ !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
+ })
+ }
+ }
+}
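
LeftSemiJoinHash above only ever needs the right side's keys: it builds them into a HashSet and
filters the streamed left side against it, emitting each left row at most once. The same idea over
plain Scala collections (an illustrative sketch, not the operator):

    object LeftSemiJoinSketch {
      def leftSemiJoin[K, A](left: Iterator[(K, A)], rightKeys: Iterator[K]): Iterator[(K, A)] = {
        val keys = rightKeys.filter(_ != null).toSet          // analogous to the build-side HashSet
        left.filter { case (k, _) => k != null && keys.contains(k) }
      }

      def main(args: Array[String]): Unit = {
        val left = Iterator(("a", 1), ("b", 2), ("c", 3))
        val right = Iterator("a", "c", "c")
        println(leftSemiJoin(left, right).toList)             // List((a,1), (c,3))
      }
    }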
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
new file mode 100644
index 0000000000..8247304c1d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Performs an inner hash join of two child relations by first shuffling the data using the join
+ * keys.
+ */
+@DeveloperApi
+case class ShuffledHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ buildSide: BuildSide,
+ left: SparkPlan,
+ right: SparkPlan)
+ extends BinaryNode with HashJoin {
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def execute() = {
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) {
+ (buildIter, streamIter) => joinIterators(buildIter, streamIter)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala
new file mode 100644
index 0000000000..7f2ab1765b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Physical execution operators for join operations.
+ */
+package object joins {
+
+ @DeveloperApi
+ sealed abstract class BuildSide
+
+ @DeveloperApi
+ case object BuildRight extends BuildSide
+
+ @DeveloperApi
+ case object BuildLeft extends BuildSide
+
+}