From d76592276f9f66fed8012d876595de8717f516a9 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 6 Apr 2016 19:25:10 -0700
Subject: [SPARK-12610][SQL] Left Anti Join

### What changes were proposed in this pull request?

This PR adds support for `LEFT ANTI JOIN` to Spark SQL. A `LEFT ANTI JOIN` is the exact opposite of a `LEFT SEMI JOIN` and can be used to identify rows in one dataset that are not in another dataset. Note that `null`s on the left side of the join cannot match a row on the right-hand side of the join; as a result, a left anti join will always select a row with a `null` in one or more of its keys.

We currently add support for the following SQL join syntax:

    SELECT * FROM tbl1 A LEFT ANTI JOIN tbl2 B ON A.Id = B.Id

Or using a DataFrame:

    tbl1.as("a").join(tbl2.as("b"), $"a.id" === $"b.id", "left_anti")

This PR serves as the basis for implementing `NOT EXISTS` and `NOT IN (...)` correlated sub-queries. It would also serve as a good basis for implementing a more efficient `EXCEPT` operator.

The PR has been (loosely) based on PRs by both davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/10563); credit should be given where credit is due.

This PR adds support for `LEFT ANTI JOIN` to `BroadcastHashJoin` (including code generation), `ShuffledHashJoin` and `BroadcastNestedLoopJoin`.

### How was this patch tested?

Added tests to `JoinSuite` and ported `ExistenceJoinSuite` from https://github.com/apache/spark/pull/10563.

cc davies chenghao-intel rxin

Author: Herman van Hovell

Closes #12214 from hvanhovell/SPARK-12610.

---
 .../apache/spark/sql/catalyst/parser/SqlBase.g4    |   2 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   2 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   8 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |   1 +
 .../spark/sql/catalyst/plans/joinTypes.scala       |  17 ++-
 .../catalyst/plans/logical/basicOperators.scala    |   4 +-
 .../sql/catalyst/parser/PlanParserSuite.scala      |   5 +-
 .../apache/spark/sql/execution/SparkPlanner.scala  |   2 +-
 .../spark/sql/execution/SparkStrategies.scala      |  11 +-
 .../sql/execution/joins/BroadcastHashJoin.scala    |  99 +++++++++----
 .../execution/joins/BroadcastNestedLoopJoin.scala  |  57 +++++---
 .../spark/sql/execution/joins/HashJoin.scala       |  18 ++-
 .../sql/execution/joins/ShuffledHashJoin.scala     |   1 +
 .../scala/org/apache/spark/sql/JoinSuite.scala     |  36 ++---
 .../sql/execution/joins/ExistenceJoinSuite.scala   | 159 +++++++++++++++++++++
 .../spark/sql/execution/joins/SemiJoinSuite.scala  | 129 -----------------
 .../apache/spark/sql/hive/HiveSessionState.scala   |   2 +-
 17 files changed, 338 insertions(+), 215 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
 delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala

diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 8a45b4f2e1..85cb585919 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -380,6 +380,7 @@ joinType | LEFT SEMI | RIGHT OUTER? | FULL OUTER? + | LEFT? ANTI ; joinCriteria @@ -878,6 +879,7 @@ INDEX: 'INDEX'; INDEXES: 'INDEXES'; LOCKS: 'LOCKS'; OPTION: 'OPTION'; +ANTI: 'ANTI'; STRING : '\'' ( ~('\''|'\\') | ('\\' .)
)* '\'' diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 473c91e69e..bc8cf4e78a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1424,7 +1424,7 @@ class Analyzer( val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) - case LeftSemi => + case LeftExistence(_) => leftKeys ++ lUniqueOutput case RightOuter => rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c085a377ff..f581810c26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -361,8 +361,8 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case j @ Join(left, right, LeftSemi, condition) => + // Eliminate unneeded attributes from right side of a Left Existence Join. + case j @ Join(left, right, LeftExistence(_), condition) => j.copy(right = prunedChild(right, j.references)) // all the columns will be used to compare, so we can't prune them @@ -1126,7 +1126,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) - case _ @ (LeftOuter | LeftSemi) => + case LeftOuter | LeftExistence(_) => // push down the left side only `where` condition val newLeft = leftFilterConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -1147,7 +1147,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case _ @ (Inner | LeftSemi) => + case Inner | LeftExistence(_) => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. 
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5a3aebff09..aa59f3fb2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -572,6 +572,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case null => Inner case jt if jt.FULL != null => FullOuter case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti case jt if jt.LEFT != null => LeftOuter case jt if jt.RIGHT != null => RightOuter case _ => Inner diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 9ca4f13dd7..13f57c54a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -26,13 +26,15 @@ object JoinType { case "leftouter" | "left" => LeftOuter case "rightouter" | "right" => RightOuter case "leftsemi" => LeftSemi + case "leftanti" => LeftAnti case _ => val supported = Seq( "inner", "outer", "full", "fullouter", "leftouter", "left", "rightouter", "right", - "leftsemi") + "leftsemi", + "leftanti") throw new IllegalArgumentException(s"Unsupported join type '$typ'. " + "Supported join types include: " + supported.mkString("'", "', '", "'") + ".") @@ -63,6 +65,10 @@ case object LeftSemi extends JoinType { override def sql: String = "LEFT SEMI" } +case object LeftAnti extends JoinType { + override def sql: String = "LEFT ANTI" +} + case class NaturalJoin(tpe: JoinType) extends JoinType { require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), "Unsupported natural join type " + tpe) @@ -70,7 +76,14 @@ case class NaturalJoin(tpe: JoinType) extends JoinType { } case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType { - require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe), + require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter, LeftAnti).contains(tpe), "Unsupported using join type " + tpe) override def sql: String = "USING " + tpe.sql } + +object LeftExistence { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a18efc90ab..d3353beb09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -252,7 +252,7 @@ case class Join( override def output: Seq[Attribute] = { joinType match { - case LeftSemi => + case LeftExistence(_) => left.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -276,7 +276,7 @@ case class Join( .union(splitConjunctivePredicates(condition.get).toSet) case Inner => left.constraints.union(right.constraints) - case LeftSemi => + case LeftExistence(_) => left.constraints case LeftOuter => left.constraints diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 262537d9c7..411e2372f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -334,7 +334,7 @@ class PlanParserSuite extends PlanTest { table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) } val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) - + val testExistence = Seq(testUnconditionalJoin, testConditionalJoin, testUsingJoin) def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { tests.foreach(_(sql, jt)) } @@ -348,6 +348,9 @@ class PlanParserSuite extends PlanTest { test("right outer join", RightOuter, testAll) test("full join", FullOuter, testAll) test("full outer join", FullOuter, testAll) + test("left semi join", LeftSemi, testExistence) + test("left anti join", LeftAnti, testExistence) + test("anti join", LeftAnti, testExistence) // Test multiple consecutive joins assertEqual( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index ac8072f3ca..8d05ae470d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -38,7 +38,7 @@ class SparkPlanner( DDLStrategy :: SpecialLimits :: Aggregation :: - LeftSemiJoin :: + ExistenceJoin :: EquiJoinSelection :: InMemoryScans :: BasicOperators :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d77aba7260..eee2b946e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -62,16 +62,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object LeftSemiJoin extends Strategy with PredicateHelper { + object ExistenceJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys( - LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) => Seq(joins.BroadcastHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) // Find left semi joins where at least some predicates can be evaluated by matching join keys - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys( + LeftExistence(jt), leftKeys, rightKeys, condition, left, right) => Seq(joins.ShuffledHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 67ac9e94ff..e3d554c2de 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -87,6 +86,7 @@ case class BroadcastHashJoin( case Inner => codegenInner(ctx, input) case LeftOuter | RightOuter => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) + case LeftAnti => codegenAnti(ctx, input) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") @@ -160,15 +160,14 @@ case class BroadcastHashJoin( } /** - * Generates the code for Inner join. + * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi + * and Left Anti joins. */ - private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + private def getJoinCondition( + ctx: CodegenContext, + input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { val matched = ctx.freshName("matched") val buildVars = genBuildSideVars(ctx, matched) - val numOutput = metricTerm(ctx, "numOutputRows") - val checkCondition = if (condition.isDefined) { val expr = condition.get // evaluate the variables from build side that used by condition @@ -184,6 +183,17 @@ case class BroadcastHashJoin( } else { "" } + (matched, checkCondition, buildVars) + } + + /** + * Generates the code for Inner join. + */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") val resultVars = buildSide match { case BuildLeft => buildVars ++ input @@ -221,7 +231,6 @@ case class BroadcastHashJoin( } } - /** * Generates the code for left or right outer join. 
*/ @@ -276,7 +285,6 @@ case class BroadcastHashJoin( ctx.copyResult = true val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName - val i = ctx.freshName("i") val found = ctx.freshName("found") s""" |// generate join key for stream side @@ -304,26 +312,8 @@ case class BroadcastHashJoin( private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val matched = ctx.freshName("matched") - val buildVars = genBuildSideVars(ctx, matched) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) val numOutput = metricTerm(ctx, "numOutputRows") - - val checkCondition = if (condition.isDefined) { - val expr = condition.get - // evaluate the variables from build side that used by condition - val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) - // filter the output via condition - ctx.currentVars = input ++ buildVars - val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx) - s""" - |$eval - |${ev.code} - |if (${ev.isNull} || !${ev.value}) continue; - """.stripMargin - } else { - "" - } - if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side @@ -357,4 +347,57 @@ case class BroadcastHashJoin( """.stripMargin } } + + /** + * Generates the code for anti join. + */ + private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + | if ($matched != null) { + | // Evaluate the condition. + | $checkCondition + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value}); + | if ($matches != null) { + | // Evaluate the condition. 
+ | boolean $found = false; + | while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $found = true; + | } + | if ($found) continue; + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 4143e944e5..4ba710c10a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case LeftSemi => + case LeftExistence(_) => left.output case x => throw new IllegalArgumentException( @@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin( * The implementation for these joins: * * LeftSemi with BuildRight + * Anti with BuildRight */ - private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + private def leftExistenceJoin( + relation: Broadcast[Array[InternalRow]], + exists: Boolean): RDD[InternalRow] = { assert(buildSide == BuildRight) streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value @@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin( if (condition.isDefined) { streamedIter.filter(l => - buildRows.exists(r => boundCondition(joinedRow(l, r))) + buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists ) + } else if (buildRows.nonEmpty == exists) { + streamedIter } else { - streamedIter.filter(r => !buildRows.isEmpty) + Iterator.empty } } } @@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin( * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ @@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin( } i += 1 } - return sparkContext.makeRDD(buf.toSeq) + return sparkContext.makeRDD(buf) + } + + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 + } + buf + } + + if (joinType == LeftAnti) { + return sparkContext.makeRDD(notMatchedBroadcastRows) } val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => @@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin( } } - val notMatchedBroadcastRows: Seq[InternalRow] = { - val nulls = new GenericMutableRow(streamed.output.size) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val buildRows = relation.value - val joinedRow = new JoinedRow - joinedRow.withLeft(nulls) - while (i < buildRows.length) { - if (!matchedBroadcastRows.get(i)) { - buf += joinedRow.withRight(buildRows(i)).copy() - } - i += 1 - } - buf.toSeq - } - sparkContext.union( matchedStreamRows, sparkContext.makeRDD(notMatchedBroadcastRows) @@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin( case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) case (LeftSemi, BuildRight) => - leftSemiJoin(broadcastedRelation) + leftExistenceJoin(broadcastedRelation, exists = true) + case (LeftAnti, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = false) case _ => /** * LeftOuter with BuildLeft * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ defaultJoin(broadcastedRelation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b7c0f3e7d1..8f45d57126 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -47,7 +47,7 @@ trait HashJoin { left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output - case LeftSemi => + case LeftExistence(_) => left.output case x => throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") @@ -197,6 +197,20 @@ trait HashJoin { } } + private def antiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter { current => + val key = joinKeys(current) + lazy val buildIter = hashedRelation.get(key) + key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists { + row => boundCondition(joinedRow(current, row)) + }) + } + } + protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, @@ -209,6 +223,8 @@ trait HashJoin { outerJoin(streamedIter, hashed) case LeftSemi => semiJoin(streamedIter, hashed) + case LeftAnti => + antiJoin(streamedIter, hashed) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c63faacf33..bf86096379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,6 +45,7 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = joinType match { case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftAnti => left.outputPartitioning case LeftSemi => left.outputPartitioning case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a5a4ff13de..a87a41c126 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -41,7 +41,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { assert(planned.size === 1) } - def assertJoin(sqlString: String, c: Class[_]): Any = { + def assertJoin(pair: (String, Class[_])): Any = { + val (sqlString, c) = pair val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { @@ -53,8 +54,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + if (operators.head.getClass != c) { + fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical") } } @@ -93,8 +94,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoin]) + ).foreach(assertJoin) } } @@ -114,7 +117,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } @@ -129,7 +132,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } @@ -419,25 +422,22 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 10)) } - test("broadcasted left semi join operator selection") { + test("broadcasted existence join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastHashJoin]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON 
key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData ANT JOIN testData2 ON key = a", classOf[BroadcastHashJoin]) + ).foreach(assertJoin) } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) + ).foreach(assertJoin) } sql("UNCACHE TABLE testData") @@ -489,7 +489,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) checkAnswer( sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala new file mode 100644 index 0000000000..8cdfa8afd0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + private lazy val conditionNEQ = { + And((left.col("a") < right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testExistenceJoin( + testName: String, + joinType: JoinType, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Row]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + ShuffledHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + 
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + } + + testExistenceJoin( + "basic test for left semi join", + LeftSemi, + left, + right, + condition, + Seq(Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for left semi non equal join", + LeftSemi, + left, + right, + conditionNEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for anti join", + LeftAnti, + left, + right, + condition, + Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) + + testExistenceJoin( + "basic test for anti non equal join", + LeftAnti, + left, + right, + conditionNEQ, + Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala deleted file mode 100644 index 985a96f684..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} -import org.apache.spark.sql.execution.exchange.EnsureRequirements -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} - -class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { - - private lazy val left = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(1, 2.0), - Row(2, 1.0), - Row(2, 1.0), - Row(3, 3.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - private lazy val right = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(2, 3.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - private lazy val condition = { - And((left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) - } - - // Note: the input dataframes and expression must be evaluated lazily because - // the SQLContext should be used only within a test to keep SQL tests stable - private def testLeftSemiJoin( - testName: String, - leftRows: => DataFrame, - rightRows: => DataFrame, - condition: => Expression, - expectedAnswer: Seq[Product]): Unit = { - - def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join) - } - - test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext.sessionState.conf).apply( - ShuffledHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using BroadcastHashJoin") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using BroadcastNestedLoopJoin build left") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)), - expectedAnswer.map(Row.fromTuple), 
- sortAnswers = true) - } - } - } - - testLeftSemiJoin( - "basic test", - left, - right, - condition, - Seq( - (2, 1.0), - (2, 1.0) - ) - ) -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index cff24e28fd..b992fda18c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -92,7 +92,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) DataSinks, Scripts, Aggregation, - LeftSemiJoin, + ExistenceJoin, EquiJoinSelection, BasicOperators, BroadcastNestedLoop,
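
For reference, a minimal usage sketch of the join type this patch adds, written against the DataFrame API of this era. The table names and values are hypothetical illustrations (an active `sqlContext` with its implicits is assumed), not part of the patch:

    import sqlContext.implicits._

    // Hypothetical data: customers with and without orders. The None id
    // becomes a null join key, which can never match a right-hand row.
    val customers = Seq(
      (Some(1), "alice"),
      (Some(2), "bob"),
      (None, "carol")
    ).toDF("id", "name")
    val orders = Seq((1, 10.0), (1, 20.0)).toDF("customerId", "amount")

    // Rows of `customers` with no matching row in `orders`.
    val noOrders = customers.join(orders, $"id" === $"customerId", "left_anti")
    // Expected: (2, "bob") and (null, "carol") -- the null-keyed row always
    // survives, since a null key cannot match anything on the right.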
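
The `LeftExistence` extractor added to `joinTypes.scala` is what lets the analyzer, optimizer, and the physical operators above treat semi and anti joins uniformly. A self-contained sketch of the pattern, with a trimmed-down `JoinType` hierarchy standing in for the real one:

    sealed trait JoinType
    case object Inner extends JoinType
    case object LeftSemi extends JoinType
    case object LeftAnti extends JoinType

    // Matches either existence join and hands the concrete type back to the
    // caller, mirroring the unapply introduced in joinTypes.scala.
    object LeftExistence {
      def unapply(joinType: JoinType): Option[JoinType] = joinType match {
        case LeftSemi | LeftAnti => Some(joinType)
        case _ => None
      }
    }

    def output(jt: JoinType): String = jt match {
      case LeftExistence(t) => s"left output only ($t)" // one arm covers both
      case _ => "combined output"
    }

This is why much of the diff above is mechanical: existing `case LeftSemi =>` arms become `case LeftExistence(_) =>` and keep working for the new join type.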
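
Finally, a plain-Scala model of the row-filtering rule implemented by the new `antiJoin` method in `HashJoin`. The types are simplified stand-ins (a `Map` in place of a `HashedRelation`, `Option` in place of a nullable key), not Spark classes:

    // A streamed row: an optional join key (None models a null key) and a payload.
    case class StreamedRow(key: Option[Int], value: Double)

    def antiJoin(
        stream: Iterator[StreamedRow],
        build: Map[Int, Seq[StreamedRow]],
        condition: Option[(StreamedRow, StreamedRow) => Boolean]): Iterator[StreamedRow] = {
      stream.filter { left =>
        left.key match {
          // A null key can never match, so the row always survives the anti join.
          case None => true
          case Some(k) =>
            val matches = build.getOrElse(k, Seq.empty)
            // Keep the row only when no build-side row matches it; with no extra
            // condition, any row with the same key counts as a match.
            !matches.exists(right => condition.forall(c => c(left, right)))
        }
      }
    }

This mirrors the short-circuiting in the real method: `key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists(...))`.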