author     Zongheng Yang <zongheng.y@gmail.com>        2014-07-31 19:32:16 -0700
committer  Michael Armbrust <michael@databricks.com>   2014-07-31 19:32:16 -0700
commit     8f51491ea78d8e88fc664c2eac3b4ac14226d98f (patch)
tree       280853242a7533e518e462806dfd83a3e653370e /sql
parent     ef4ff00f87a4e8d38866f163f01741c2673e41da (diff)
[SPARK-2531 & SPARK-2436] [SQL] Optimize the BuildSide when planning BroadcastNestedLoopJoin.
This PR resolves the following two tickets:

- [SPARK-2531](https://issues.apache.org/jira/browse/SPARK-2531): BNLJ currently assumes the build side is the right relation. This patch refactors some of its logic to take a BuildSide properly into account.
- [SPARK-2436](https://issues.apache.org/jira/browse/SPARK-2436): building on top of the above, we simply use the physical size statistics (if available) of both relations and make the smaller relation the build side in the planner.

Author: Zongheng Yang <zongheng.y@gmail.com>

Closes #1448 from concretevitamin/bnlj-buildSide and squashes the following commits:

1780351 [Zongheng Yang] Use size estimation to decide optimal build side of BNLJ.
68e6c5b [Zongheng Yang] Consolidate two adjacent pattern matchings.
96d312a [Zongheng Yang] Use a while loop instead of collection methods chaining.
4bc525e [Zongheng Yang] Make BroadcastNestedLoopJoin take a BuildSide.
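As a hedged illustration of the SPARK-2436 half of the change, here is a minimal sketch of the size-based build-side choice. Stats and Plan are simplified stand-ins for Catalyst's logical-plan statistics, not Spark's actual types; the real logic is in the SparkStrategies.scala hunk below.

    sealed trait BuildSide
    case object BuildLeft extends BuildSide
    case object BuildRight extends BuildSide

    // Simplified stand-ins for LogicalPlan and its Statistics.
    final case class Stats(sizeInBytes: BigInt)
    final case class Plan(statistics: Stats)

    // Broadcast the smaller relation; a tie keeps the old default, BuildRight.
    def chooseBuildSide(left: Plan, right: Plan): BuildSide =
      if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight
      else BuildLeft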
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala           | 79
2 files changed, 55 insertions(+), 28 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5f1fe99f75..d57b6eaf40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -155,8 +155,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object BroadcastNestedLoopJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
+ val buildSide =
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft
execution.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), joinType, condition) :: Nil
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case _ => Nil
}
}
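Before the joins.scala hunks, one subtlety worth noting for SPARK-2531: output rows must keep the (left columns, right columns) order no matter which side is broadcast, so the rewritten match loop flips the arguments to JoinedRow by build side. A minimal self-contained sketch of that invariant, with rows modeled as Seq[Any] rather than Spark's Row:

    sealed trait BuildSide
    case object BuildLeft extends BuildSide
    case object BuildRight extends BuildSide

    type Row = Seq[Any]

    // The streamed row sits on whichever side was NOT broadcast.
    def joinOutput(streamedRow: Row, broadcastRow: Row, side: BuildSide): Row =
      side match {
        case BuildRight => streamedRow ++ broadcastRow // streamed child is `left`
        case BuildLeft  => broadcastRow ++ streamedRow // streamed child is `right`
      }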
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 2750ddbce8..b068579db7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -314,10 +314,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
*/
@DeveloperApi
case class BroadcastNestedLoopJoin(
- streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
- extends BinaryNode {
+ left: SparkPlan,
+ right: SparkPlan,
+ buildSide: BuildSide,
+ joinType: JoinType,
+ condition: Option[Expression]) extends BinaryNode {
// TODO: Override requiredChildDistribution.
+ /** BuildRight means the right relation <=> the broadcast relation. */
+ val (streamed, broadcast) = buildSide match {
+ case BuildRight => (left, right)
+ case BuildLeft => (right, left)
+ }
+
override def outputPartitioning: Partitioning = streamed.outputPartitioning
override def output = {
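So with BuildLeft the operator streams the right child and broadcasts the left one, and vice versa. A toy check of that mapping, using stand-in types rather than SparkPlan:

    sealed trait BuildSide
    case object BuildLeft extends BuildSide
    case object BuildRight extends BuildSide

    // Returns (streamed, broadcast), mirroring the pattern match above.
    def sides[A](left: A, right: A, side: BuildSide): (A, A) = side match {
      case BuildRight => (left, right)
      case BuildLeft  => (right, left)
    }

    assert(sides("L", "R", BuildLeft) == ("R", "L")) // the left child is broadcast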
@@ -333,11 +342,6 @@ case class BroadcastNestedLoopJoin(
}
}
- /** The Streamed Relation */
- def left = streamed
- /** The Broadcast relation */
- def right = broadcast
-
@transient lazy val boundCondition =
InterpretedPredicate(
condition
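boundCondition (an unchanged context line above, truncated by the hunk) semantically defaults the optional join condition to an always-true predicate. A hedged sketch of that behavior, with Row => Boolean standing in for Catalyst's InterpretedPredicate:

    type Row = Seq[Any]

    // With no condition, every joined row passes (Literal(true) in the real code).
    def bindCondition(condition: Option[Row => Boolean]): Row => Boolean =
      condition.getOrElse((_: Row) => true)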
@@ -348,57 +352,78 @@ case class BroadcastNestedLoopJoin(
val broadcastedRelation =
sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
- val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
+ /** All rows that either match both-way, or rows from streamed joined with nulls. */
+ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row]
// TODO: Use Spark's BitSet.
- val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
+ val includedBroadcastTuples =
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow
+ val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
streamedIter.foreach { streamedRow =>
var i = 0
- var matched = false
+ var streamRowMatched = false
while (i < broadcastedRelation.value.size) {
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
- if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
- matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
- matched = true
- includedBroadcastTuples += i
+ buildSide match {
+ case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
+ matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+ streamRowMatched = true
+ includedBroadcastTuples += i
+ case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
+ matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+ streamRowMatched = true
+ includedBroadcastTuples += i
+ case _ =>
}
i += 1
}
- if (!matched && (joinType == LeftOuter || joinType == FullOuter)) {
- matchedRows += joinedRow(streamedRow, rightNulls).copy()
+ (streamRowMatched, joinType, buildSide) match {
+ case (false, LeftOuter | FullOuter, BuildRight) =>
+ matchedRows += joinedRow(streamedRow, rightNulls).copy()
+ case (false, RightOuter | FullOuter, BuildLeft) =>
+ matchedRows += joinedRow(leftNulls, streamedRow).copy()
+ case _ =>
}
}
Iterator((matchedRows, includedBroadcastTuples))
}
- val includedBroadcastTuples = streamedPlusMatches.map(_._2)
+ val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
val allIncludedBroadcastTuples =
if (includedBroadcastTuples.count == 0) {
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
} else {
- streamedPlusMatches.map(_._2).reduce(_ ++ _)
+ includedBroadcastTuples.reduce(_ ++ _)
}
val leftNulls = new GenericMutableRow(left.output.size)
- val rightOuterMatches: Seq[Row] =
- if (joinType == RightOuter || joinType == FullOuter) {
- broadcastedRelation.value.zipWithIndex.filter {
- case (row, i) => !allIncludedBroadcastTuples.contains(i)
- }.map {
- case (row, _) => new JoinedRow(leftNulls, row)
+ val rightNulls = new GenericMutableRow(right.output.size)
+ /** Rows from broadcasted joined with nulls. */
+ val broadcastRowsWithNulls: Seq[Row] = {
+ val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer()
+ var i = 0
+ val rel = broadcastedRelation.value
+ while (i < rel.length) {
+ if (!allIncludedBroadcastTuples.contains(i)) {
+ (joinType, buildSide) match {
+ case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i))
+ case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls)
+ case _ =>
+ }
}
- } else {
- Vector()
+ i += 1
}
+ arrBuf.toSeq
+ }
// TODO: Breaks lineage.
sparkContext.union(
- streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches))
+ matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
}
}
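To summarize the rewritten second pass: every broadcast row whose index never entered the merged bitset is emitted once, padded with nulls on the streamed side, for the applicable outer-join types. A self-contained sketch under simplified types; the join-type guard (LeftOuter/RightOuter/FullOuter) is elided for brevity, and the while loop mirrors the patch's choice over collection-method chaining (commit 96d312a):

    import scala.collection.mutable

    sealed trait BuildSide
    case object BuildLeft extends BuildSide
    case object BuildRight extends BuildSide

    type Row = Seq[Any]

    def unmatchedBroadcastRows(
        broadcasted: IndexedSeq[Row],
        matched: mutable.BitSet, // union of per-partition includedBroadcastTuples
        streamedNulls: Row,
        side: BuildSide): Seq[Row] = {
      val buf = mutable.ArrayBuffer.empty[Row]
      var i = 0
      while (i < broadcasted.length) {
        if (!matched.contains(i)) {
          side match {
            // BuildRight: broadcast rows are right-side columns, pad the left.
            case BuildRight => buf += streamedNulls ++ broadcasted(i)
            // BuildLeft: broadcast rows are left-side columns, pad the right.
            case BuildLeft  => buf += broadcasted(i) ++ streamedNulls
          }
        }
        i += 1
      }
      buf.toSeq
    }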