-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala          24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala            81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala        5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala       46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala      11
6 files changed, 96 insertions, 72 deletions
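
Summary: this patch pushes an optional non-equi join condition into BroadcastHashJoin and SortMergeJoin themselves, where it is evaluated per matched pair, instead of planning a separate Filter operator on top of the join. Evaluating the condition inside the probe loop avoids emitting rows that a downstream Filter would immediately discard. A minimal, dependency-free sketch of the resulting probe-side behavior (all names here are illustrative, not Spark's):

```scala
// Toy model of a hash join whose probe loop applies an optional extra
// condition per matched pair, mirroring the pattern this patch introduces.
object ConditionalHashJoinSketch {
  type Row = (Int, String) // (key, payload); illustrative row shape

  def hashJoin(
      build: Seq[Row],
      stream: Seq[Row],
      condition: Option[(Row, Row) => Boolean]): Seq[(Row, Row)] = {
    // Bind the optional condition once; constant-true when absent.
    val bound: (Row, Row) => Boolean = condition.getOrElse((_, _) => true)
    val hashed: Map[Int, Seq[Row]] = build.groupBy(_._1)
    for {
      s <- stream
      m <- hashed.getOrElse(s._1, Seq.empty)
      if bound(s, m) // evaluated inside the join, not in a separate Filter
    } yield (s, m)
  }

  def main(args: Array[String]): Unit = {
    val left = Seq((1, "a"), (2, "b"))
    val right = Seq((1, "x"), (1, "y"), (3, "z"))
    // Keep only matches whose build-side payload is not "y".
    println(hashJoin(right, left, Some((_, b) => b._2 != "y")))
    // List(((1,a),(1,x)))
  }
}
```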
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 910519d0e6..df0f730499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.{execution, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object EquiJoinSelection extends Strategy with PredicateHelper {
- private[this] def makeBroadcastHashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- left: LogicalPlan,
- right: LogicalPlan,
- condition: Option[Expression],
- side: joins.BuildSide): Seq[SparkPlan] = {
- val broadcastHashJoin = execution.joins.BroadcastHashJoin(
- leftKeys, rightKeys, side, planLater(left), planLater(right))
- condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
- }
-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// --- Inner joins --------------------------------------------------------------------------
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
+ joins.BroadcastHashJoin(
+ leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+ joins.BroadcastHashJoin(
+ leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
- val mergeJoin =
- joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
- condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+ joins.SortMergeJoin(
+ leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
// --- Outer joins --------------------------------------------------------------------------
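
Note on the strategy change above: a query's extra predicate no longer becomes a Filter node wrapping the join. A hypothetical mini-ADT (not Spark's classes) makes the before/after tree shape concrete:

```scala
// Illustrative plan ADT; the real SparkPlan hierarchy is far richer.
sealed trait Plan
case class Scan(table: String) extends Plan
case class FilterNode(cond: String, child: Plan) extends Plan
case class HashJoinNode(cond: Option[String], left: Plan, right: Plan) extends Plan

object PlanShapes {
  // Before: the non-equi predicate became a separate Filter over the join.
  val before: Plan = FilterNode("a.v > b.v", HashJoinNode(None, Scan("a"), Scan("b")))
  // After: the join node carries the predicate itself; one operator fewer.
  val after: Plan = HashJoinNode(Some("a.v > b.v"), Scan("a"), Scan("b"))
}
```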
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 0a818cc2c2..c9ea579b5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
+ condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7f9d9daa5a..8ef854001f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.joins
+import java.util.NoSuchElementException
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val buildSide: BuildSide
+ val condition: Option[Expression]
val left: SparkPlan
val right: SparkPlan
@@ -50,6 +53,12 @@ trait HashJoin {
protected def streamSideKeyGenerator: Projection =
UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+ newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ } else {
+ (r: InternalRow) => true
+ }
+
protected def hashJoin(
streamIter: Iterator[InternalRow],
numStreamRows: LongSQLMetric,
@@ -68,44 +77,52 @@ trait HashJoin {
private[this] val joinKeys = streamSideKeyGenerator
- override final def hasNext: Boolean =
- (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
- (streamIter.hasNext && fetchNext())
+ override final def hasNext: Boolean = {
+ while (true) {
+ // check if it's end of current matches
+ if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
+ currentHashMatches = null
+ currentMatchPosition = -1
+ }
- override final def next(): InternalRow = {
- val ret = buildSide match {
- case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
- }
- currentMatchPosition += 1
- numOutputRows += 1
- resultProjection(ret)
- }
+ // find the next match
+ while (currentHashMatches == null && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ numStreamRows += 1
+ val key = joinKeys(currentStreamedRow)
+ if (!key.anyNull) {
+ currentHashMatches = hashedRelation.get(key)
+ if (currentHashMatches != null) {
+ currentMatchPosition = 0
+ }
+ }
+ }
+ if (currentHashMatches == null) {
+ return false
+ }
- /**
- * Searches the streamed iterator for the next row that has at least one match in hashtable.
- *
- * @return true if the search is successful, and false if the streamed iterator runs out of
- * tuples.
- */
- private final def fetchNext(): Boolean = {
- currentHashMatches = null
- currentMatchPosition = -1
-
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- numStreamRows += 1
- val key = joinKeys(currentStreamedRow)
- if (!key.anyNull) {
- currentHashMatches = hashedRelation.get(key)
+ // found some matches
+ buildSide match {
+ case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+ case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+ }
+ if (boundCondition(joinRow)) {
+ return true
+ } else {
+ currentMatchPosition += 1
}
}
+ false // unreachable
+ }
- if (currentHashMatches == null) {
- false
+ override final def next(): InternalRow = {
+ // next() could be called without calling hasNext()
+ if (hasNext) {
+ currentMatchPosition += 1
+ numOutputRows += 1
+ resultProjection(joinRow)
} else {
- currentMatchPosition = 0
- true
+ throw new NoSuchElementException
}
}
}
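
Note: the rewritten HashJoin iterator moves all matching and condition checking into hasNext, which therefore must be idempotent, while next() re-checks it defensively and throws NoSuchElementException when exhausted. A standalone sketch of that contract, with hypothetical names:

```scala
import java.util.NoSuchElementException

// Lookahead iterator: hasNext finds and caches the next passing element,
// so repeated hasNext calls are idempotent and next() can call it safely.
final class FilteringIterator[A](underlying: Iterator[A], pred: A => Boolean)
    extends Iterator[A] {
  private[this] var lookahead: Option[A] = None

  override def hasNext: Boolean = {
    while (lookahead.isEmpty && underlying.hasNext) {
      val a = underlying.next()
      if (pred(a)) lookahead = Some(a)
    }
    lookahead.isDefined
  }

  override def next(): A = {
    if (hasNext) { // tolerate next() without a prior hasNext()
      val a = lookahead.get
      lookahead = None
      a
    } else {
      throw new NoSuchElementException
    }
  }
}
```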
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6d464d6946..9e614309de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -78,8 +78,11 @@ trait HashOuterJoin {
@transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
- @transient private[this] lazy val boundCondition =
+ @transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ } else {
+ (row: InternalRow) => true
+ }
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose.
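
Note: the new isDefined guard is a small fast path. When the join has no extra condition, the operator uses a shared constant-true function instead of compiling and invoking a Literal(true) predicate for every row. The same pattern reduced to plain Scala (illustrative types):

```scala
object ConditionFastPath {
  // With no condition, hand back a shared constant-true function rather than
  // compiling (and invoking) a predicate that always evaluates to true.
  def boundCondition[Row](compiled: Option[Row => Boolean]): Row => Boolean =
    compiled.getOrElse((_: Row) => true)
}
```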
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 812f881d06..322a954b4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
case class SortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
+ condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
@@ -64,6 +65,13 @@ case class SortMergeJoin(
val numOutputRows = longMetric("numOutputRows")
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ val boundCondition: (InternalRow) => Boolean = {
+ condition.map { cond =>
+ newPredicate(cond, left.output ++ right.output)
+ }.getOrElse {
+ (r: InternalRow) => true
+ }
+ }
new RowIterator {
// The projection used to extract keys from input rows of the left child.
private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@ case class SortMergeJoin(
private[this] val resultProjection: (InternalRow) => InternalRow =
UnsafeProjection.create(schema)
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ }
+
override def advanceNext(): Boolean = {
- if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
- if (smjScanner.findNextInnerJoinRows()) {
- currentRightMatches = smjScanner.getBufferedMatches
- currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
- } else {
- currentRightMatches = null
- currentLeftRow = null
- currentMatchIdx = -1
+ while (currentMatchIdx >= 0) {
+ if (currentMatchIdx == currentRightMatches.length) {
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ } else {
+ currentRightMatches = null
+ currentLeftRow = null
+ currentMatchIdx = -1
+ return false
+ }
}
- }
- if (currentLeftRow != null) {
joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
currentMatchIdx += 1
- numOutputRows += 1
- true
- } else {
- false
+ if (boundCondition(joinRow)) {
+ numOutputRows += 1
+ return true
+ }
}
+ false
}
override def getRow: InternalRow = resultProjection(joinRow)
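
Note: in SortMergeJoin, advanceNext() now loops in place past matched pairs that fail the condition, and the first batch of matches is primed eagerly so that currentMatchIdx >= 0 reliably means a batch is loaded. A self-contained sketch of that loop (simplified; not the real RowIterator API):

```scala
// Sketch: advance-style iteration over pre-grouped matches, filtering by an
// optional condition without routing rows through a separate Filter operator.
final class MergeMatches[L, R](
    groups: Iterator[(L, IndexedSeq[R])],
    cond: (L, R) => Boolean) {
  private[this] var leftRow: L = _
  private[this] var matches: IndexedSeq[R] = _
  private[this] var idx: Int = -1
  private[this] var current: (L, R) = _

  // Prime the first group, mirroring the eager findNextInnerJoinRows() call.
  if (groups.hasNext) {
    val (l, ms) = groups.next(); leftRow = l; matches = ms; idx = 0
  }

  def advanceNext(): Boolean = {
    while (idx >= 0) {
      if (idx == matches.length) {
        if (groups.hasNext) {
          val (l, ms) = groups.next(); leftRow = l; matches = ms; idx = 0
        } else {
          idx = -1
          return false
        }
      }
      val candidate = (leftRow, matches(idx)); idx += 1
      if (cond(candidate._1, candidate._2)) { current = candidate; return true }
    }
    false
  }

  def getRow: (L, R) = current
}
```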
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 42fadaa8e2..ab81b70259 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
@@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- val broadcastHashJoin =
- execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
- boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+ joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
}
def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
val sortMergeJoin =
- execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
- val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
- EnsureRequirements(sqlContext).apply(filteredJoin)
+ joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+ EnsureRequirements(sqlContext).apply(sortMergeJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
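
Net effect for users (a hedged sketch against the Spark 1.6-era DataFrame API): an inner join carrying an extra non-equi predicate now plans to a single join operator that evaluates the predicate itself, rather than Filter over the join.

```scala
// Assumes a Spark 1.6-style environment; shown for plan shape only.
import org.apache.spark.sql.DataFrame

def explainJoin(a: DataFrame, b: DataFrame): Unit = {
  val joined = a.join(b, a("k") === b("k") && a("v") > b("v"))
  // Before this patch: Filter(a.v > b.v) sat on top of the join operator.
  // After: BroadcastHashJoin/SortMergeJoin carries the condition itself.
  joined.explain()
}
```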