author    Davies Liu <davies@databricks.com>  2016-01-18 17:29:54 -0800
committer Davies Liu <davies.liu@gmail.com>   2016-01-18 17:29:54 -0800
commit    323d51f1dadf733e413203d678cb3f76e4d68981 (patch)
tree      f15f31a8c2bdf7ae73c0a7ff5c4a2a0ec175fa93
parent    39ac56fc60734d0e095314fc38a7b36fbb4c80f7 (diff)
[SPARK-12700] [SQL] embed condition into SMJ and BroadcastHashJoin
Currently SortMergeJoin and BroadcastHashJoin do not support a join condition; they need a trailing Filter for that. The result projection that generates UnsafeRows can be very expensive when the join produces lots of rows that are then mostly discarded by the condition. This PR adds condition support to SortMergeJoin and BroadcastHashJoin, just as the outer joins already have. This improves the performance of Q72 by 7x (from 120s to 16.5s).

Author: Davies Liu <davies@databricks.com>

Closes #10653 from davies/filter_join.
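The core change, in miniature: before this patch the planner wrapped the join in a separate Filter, so the join's result projection built an UnsafeRow for every matched pair, including the pairs the Filter then threw away; with the condition embedded, non-matching pairs are skipped before the projection runs. A minimal sketch of the difference, using hypothetical project/condition stand-ins rather than Spark's actual operators:

    // Illustration only: `project` stands in for the expensive UnsafeRow
    // result projection, `condition` for the join condition.
    type Row = (Int, String)
    def project(r: Row): Row = (r._1, r._2.toUpperCase) // pretend this is costly
    def condition(r: Row): Boolean = r._1 % 100 == 0    // drops most rows

    // Before: the Filter sits above the join, so every joined row is projected.
    def joinThenFilter(joined: Iterator[Row]): Iterator[Row] =
      joined.map(project).filter(condition)

    // After: the condition is applied inside the join, before the projection.
    def conditionInsideJoin(joined: Iterator[Row]): Iterator[Row] =
      joined.filter(condition).map(project)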
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala      | 24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala       | 81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala  | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala  | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 11

6 files changed, 96 insertions(+), 72 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 910519d0e6..df0f730499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.{execution, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object EquiJoinSelection extends Strategy with PredicateHelper {
- private[this] def makeBroadcastHashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- left: LogicalPlan,
- right: LogicalPlan,
- condition: Option[Expression],
- side: joins.BuildSide): Seq[SparkPlan] = {
- val broadcastHashJoin = execution.joins.BroadcastHashJoin(
- leftKeys, rightKeys, side, planLater(left), planLater(right))
- condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
- }
-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// --- Inner joins --------------------------------------------------------------------------
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
+ joins.BroadcastHashJoin(
+ leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+ joins.BroadcastHashJoin(
+ leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
- val mergeJoin =
- joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
- condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+ joins.SortMergeJoin(
+ leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
// --- Outer joins --------------------------------------------------------------------------
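The deleted makeBroadcastHashJoin helper relied on the common Option-wrapping idiom, condition.map(Filter(_, join)).getOrElse(join), to bolt a Filter on top of the join only when a condition existed. A tiny standalone sketch of that idiom, with hypothetical Plan/Join/Filter case classes standing in for Spark's plan nodes:

    // Hypothetical plan nodes, for illustration only.
    sealed trait Plan
    case class Join(name: String) extends Plan
    case class Filter(cond: String, child: Plan) extends Plan

    // Pre-patch: wrap the join in a Filter only when a condition exists.
    def withCondition(join: Plan, condition: Option[String]): Plan =
      condition.map(Filter(_, join)).getOrElse(join)

    assert(withCondition(Join("smj"), None) == Join("smj"))
    assert(withCondition(Join("smj"), Some("a.x > b.y")) == Filter("a.x > b.y", Join("smj")))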
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 0a818cc2c2..c9ea579b5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
+ condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7f9d9daa5a..8ef854001f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.joins
+import java.util.NoSuchElementException
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val buildSide: BuildSide
+ val condition: Option[Expression]
val left: SparkPlan
val right: SparkPlan
@@ -50,6 +53,12 @@ trait HashJoin {
protected def streamSideKeyGenerator: Projection =
UnsafeProjection.create(streamedKeys, streamedPlan.output)
+ @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+ newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ } else {
+ (r: InternalRow) => true
+ }
+
protected def hashJoin(
streamIter: Iterator[InternalRow],
numStreamRows: LongSQLMetric,
@@ -68,44 +77,52 @@ trait HashJoin {
private[this] val joinKeys = streamSideKeyGenerator
- override final def hasNext: Boolean =
- (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
- (streamIter.hasNext && fetchNext())
+ override final def hasNext: Boolean = {
+ while (true) {
+ // check if it's end of current matches
+ if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
+ currentHashMatches = null
+ currentMatchPosition = -1
+ }
- override final def next(): InternalRow = {
- val ret = buildSide match {
- case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
- }
- currentMatchPosition += 1
- numOutputRows += 1
- resultProjection(ret)
- }
+ // find the next match
+ while (currentHashMatches == null && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ numStreamRows += 1
+ val key = joinKeys(currentStreamedRow)
+ if (!key.anyNull) {
+ currentHashMatches = hashedRelation.get(key)
+ if (currentHashMatches != null) {
+ currentMatchPosition = 0
+ }
+ }
+ }
+ if (currentHashMatches == null) {
+ return false
+ }
- /**
- * Searches the streamed iterator for the next row that has at least one match in hashtable.
- *
- * @return true if the search is successful, and false if the streamed iterator runs out of
- * tuples.
- */
- private final def fetchNext(): Boolean = {
- currentHashMatches = null
- currentMatchPosition = -1
-
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- numStreamRows += 1
- val key = joinKeys(currentStreamedRow)
- if (!key.anyNull) {
- currentHashMatches = hashedRelation.get(key)
+ // found some matches
+ buildSide match {
+ case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+ case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+ }
+ if (boundCondition(joinRow)) {
+ return true
+ } else {
+ currentMatchPosition += 1
}
}
+ false // unreachable
+ }
- if (currentHashMatches == null) {
- false
+ override final def next(): InternalRow = {
+ // next() could be called without calling hasNext()
+ if (hasNext) {
+ currentMatchPosition += 1
+ numOutputRows += 1
+ resultProjection(joinRow)
} else {
- currentMatchPosition = 0
- true
+ throw new NoSuchElementException
}
}
}
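The rewritten iterator folds the condition into hasNext: it advances through the hash matches until one passes the predicate, so rows the condition rejects never reach resultProjection. A self-contained sketch of the same pattern, assuming a plain Map in place of Spark's HashedRelation:

    // A minimal filtering hash join, for illustration only.
    def hashJoin[K, A, B](
        stream: Iterator[A],
        key: A => K,
        hashed: Map[K, Seq[B]],
        condition: (A, B) => Boolean): Iterator[(A, B)] =
      for {
        a <- stream
        b <- hashed.getOrElse(key(a), Seq.empty).iterator
        if condition(a, b) // checked before any result projection would run
      } yield (a, b)

    val right = Map(1 -> Seq("a", "x"), 2 -> Seq("b"))
    val out = hashJoin[Int, (Int, String), String](
      Iterator((1, "a"), (2, "b")), _._1, right, (l, r) => l._2 != r)
    assert(out.toList == List(((1, "a"), "x")))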
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6d464d6946..9e614309de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -78,8 +78,11 @@ trait HashOuterJoin {
@transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
- @transient private[this] lazy val boundCondition =
+ @transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ } else {
+ (row: InternalRow) => true
+ }
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose.
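The HashOuterJoin change is a micro-optimization with the same shape as the new HashJoin.boundCondition: when no condition is present, bind a constant-true Scala function instead of compiling a Literal(true) predicate, so the common no-condition path pays nothing per row. The general pattern, sketched:

    // Bind the predicate once; default to constant-true when absent.
    def bindCondition[R](condition: Option[R => Boolean]): R => Boolean =
      condition.getOrElse((_: R) => true)

    val even = bindCondition[Int](Some(_ % 2 == 0))
    assert(even(4) && !even(3))
    assert(bindCondition[Int](None)(41)) // no condition: every row passes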
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 812f881d06..322a954b4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
case class SortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
+ condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
@@ -64,6 +65,13 @@ case class SortMergeJoin(
val numOutputRows = longMetric("numOutputRows")
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ val boundCondition: (InternalRow) => Boolean = {
+ condition.map { cond =>
+ newPredicate(cond, left.output ++ right.output)
+ }.getOrElse {
+ (r: InternalRow) => true
+ }
+ }
new RowIterator {
// The projection used to extract keys from input rows of the left child.
private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@ case class SortMergeJoin(
private[this] val resultProjection: (InternalRow) => InternalRow =
UnsafeProjection.create(schema)
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ }
+
override def advanceNext(): Boolean = {
- if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
- if (smjScanner.findNextInnerJoinRows()) {
- currentRightMatches = smjScanner.getBufferedMatches
- currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
- } else {
- currentRightMatches = null
- currentLeftRow = null
- currentMatchIdx = -1
+ while (currentMatchIdx >= 0) {
+ if (currentMatchIdx == currentRightMatches.length) {
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ } else {
+ currentRightMatches = null
+ currentLeftRow = null
+ currentMatchIdx = -1
+ return false
+ }
}
- }
- if (currentLeftRow != null) {
joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
currentMatchIdx += 1
- numOutputRows += 1
- true
- } else {
- false
+ if (boundCondition(joinRow)) {
+ numOutputRows += 1
+ return true
+ }
}
+ false
}
override def getRow: InternalRow = resultProjection(joinRow)
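SortMergeJoin's advanceNext now has the same shape as the hash-join iterator: walk the buffered matches for the current streamed row, fetch the next group when they run out, and return only once a pair passes the condition. A condensed sort-merge inner join over plain sorted sequences, with the condition applied inside the merge loop (illustration only, not Spark's RowIterator API):

    // Minimal sort-merge inner join with an embedded condition, for illustration.
    def sortMergeJoin[K, A, B](left: Seq[(K, A)], right: Seq[(K, B)],
        condition: (A, B) => Boolean)(implicit ord: Ordering[K]): Seq[(A, B)] = {
      val out = scala.collection.mutable.ArrayBuffer.empty[(A, B)]
      var i = 0
      var j = 0
      while (i < left.length && j < right.length) {
        val c = ord.compare(left(i)._1, right(j)._1)
        if (c < 0) i += 1
        else if (c > 0) j += 1
        else {
          // Buffer the right-side group sharing this key (the "matches").
          val k = left(i)._1
          val start = j
          while (j < right.length && ord.equiv(right(j)._1, k)) j += 1
          // For each left row with this key, emit pairs that pass the condition.
          while (i < left.length && ord.equiv(left(i)._1, k)) {
            var m = start
            while (m < j) {
              if (condition(left(i)._2, right(m)._2)) out += ((left(i)._2, right(m)._2))
              m += 1
            }
            i += 1
          }
        }
      }
      out.toSeq
    }

    val joined = sortMergeJoin(
      Seq(1 -> "l1", 2 -> "l2", 2 -> "l2b"), Seq(2 -> "r2", 2 -> "r2x", 3 -> "r3"),
      (l: String, r: String) => l != "l2b")
    assert(joined == Seq("l2" -> "r2", "l2" -> "r2x"))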
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 42fadaa8e2..ab81b70259 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
@@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- val broadcastHashJoin =
- execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
- boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+ joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
}
def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
val sortMergeJoin =
- execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
- val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
- EnsureRequirements(sqlContext).apply(filteredJoin)
+ joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+ EnsureRequirements(sqlContext).apply(sortMergeJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {