path: root/sql
author     Herman van Hovell <hvanhovell@questtec.nl>    2016-04-06 19:25:10 -0700
committer  Reynold Xin <rxin@databricks.com>             2016-04-06 19:25:10 -0700
commit     d76592276f9f66fed8012d876595de8717f516a9 (patch)
tree       bb3570eac8b6885efe77677d18cda30df7cb0a69 /sql
parent     4901086fea969a34ec312ef4a8f83d84e1bf21fb (diff)
[SPARK-12610][SQL] Left Anti Join
### What changes were proposed in this pull request?

This PR adds support for `LEFT ANTI JOIN` to Spark SQL. A `LEFT ANTI JOIN` is the exact opposite of a `LEFT SEMI JOIN` and can be used to identify rows in one dataset that are not in another dataset. Note that `null`s on the left side of the join cannot match a row on the right hand side of the join; as a result, a left anti join will always select rows that have a `null` in one or more of their join keys.

We currently add support for the following SQL join syntax:

    SELECT * FROM tbl1 A LEFT ANTI JOIN tbl2 B ON A.Id = B.Id

Or using a DataFrame:

    tbl1.as("a").join(tbl2.as("b"), $"a.id" === $"b.id", "left_anti")

This PR serves as the basis for implementing `NOT EXISTS` and `NOT IN (...)` correlated sub-queries. It would also serve as a good basis for implementing a more efficient `EXCEPT` operator.

The PR has been (loosely) based on PRs by both davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/10563); credit should be given where credit is due.

This PR adds support for `LEFT ANTI JOIN` to `BroadcastHashJoin` (including code generation), `ShuffledHashJoin` and `BroadcastNestedLoopJoin`.

### How was this patch tested?

Added tests to `JoinSuite` and ported `ExistenceJoinSuite` from https://github.com/apache/spark/pull/10563.

cc davies chenghao-intel rxin

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12214 from hvanhovell/SPARK-12610.
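As a quick illustration of the semantics described above, here is a minimal spark-shell sketch using the new `left_anti` join type string; the table names and contents are invented for the example.

```scala
// Run in e.g. the spark-shell, where `sqlContext` and its implicits are already in scope.
import sqlContext.implicits._

// Customers with no matching order should survive the anti join.
val customers = Seq((1, "alice"), (2, "bob"), (3, "carol")).toDF("id", "name")
val orders = Seq((1, 100.0), (1, 25.0), (3, 42.0)).toDF("customer_id", "amount")

// Equivalent SQL (new syntax added by this patch):
//   SELECT * FROM customers c LEFT ANTI JOIN orders o ON c.id = o.customer_id
val noOrders = customers.as("c")
  .join(orders.as("o"), $"c.id" === $"o.customer_id", "left_anti")

noOrders.show()
// +---+----+
// | id|name|
// +---+----+
// |  2| bob|
// +---+----+
```

A left-side row whose join key is `null` can never match and is therefore always kept in the result.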
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 99
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 57
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 36
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala (renamed from sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala) | 74
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala | 2
16 files changed, 231 insertions, 108 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 8a45b4f2e1..85cb585919 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -380,6 +380,7 @@ joinType
| LEFT SEMI
| RIGHT OUTER?
| FULL OUTER?
+ | LEFT? ANTI
;
joinCriteria
@@ -878,6 +879,7 @@ INDEX: 'INDEX';
INDEXES: 'INDEXES';
LOCKS: 'LOCKS';
OPTION: 'OPTION';
+ANTI: 'ANTI';
STRING
: '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 473c91e69e..bc8cf4e78a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1424,7 +1424,7 @@ class Analyzer(
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
- case LeftSemi =>
+ case LeftExistence(_) =>
leftKeys ++ lUniqueOutput
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c085a377ff..f581810c26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -361,8 +361,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
- case j @ Join(left, right, LeftSemi, condition) =>
+ // Eliminate unneeded attributes from right side of a Left Existence Join.
+ case j @ Join(left, right, LeftExistence(_), condition) =>
j.copy(right = prunedChild(right, j.references))
// all the columns will be used to compare, so we can't prune them
@@ -1126,7 +1126,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
- case _ @ (LeftOuter | LeftSemi) =>
+ case LeftOuter | LeftExistence(_) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -1147,7 +1147,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
- case _ @ (Inner | LeftSemi) =>
+ case Inner | LeftExistence(_) =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 5a3aebff09..aa59f3fb2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -572,6 +572,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case null => Inner
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 9ca4f13dd7..13f57c54a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -26,13 +26,15 @@ object JoinType {
case "leftouter" | "left" => LeftOuter
case "rightouter" | "right" => RightOuter
case "leftsemi" => LeftSemi
+ case "leftanti" => LeftAnti
case _ =>
val supported = Seq(
"inner",
"outer", "full", "fullouter",
"leftouter", "left",
"rightouter", "right",
- "leftsemi")
+ "leftsemi",
+ "leftanti")
throw new IllegalArgumentException(s"Unsupported join type '$typ'. " +
"Supported join types include: " + supported.mkString("'", "', '", "'") + ".")
@@ -63,6 +65,10 @@ case object LeftSemi extends JoinType {
override def sql: String = "LEFT SEMI"
}
+case object LeftAnti extends JoinType {
+ override def sql: String = "LEFT ANTI"
+}
+
case class NaturalJoin(tpe: JoinType) extends JoinType {
require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
"Unsupported natural join type " + tpe)
@@ -70,7 +76,14 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
}
case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
- require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe),
+ require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter, LeftAnti).contains(tpe),
"Unsupported using join type " + tpe)
override def sql: String = "USING " + tpe.sql
}
+
+object LeftExistence {
+ def unapply(joinType: JoinType): Option[JoinType] = joinType match {
+ case LeftSemi | LeftAnti => Some(joinType)
+ case _ => None
+ }
+}
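The new `LeftExistence` extractor is what lets later rules handle semi and anti joins with a single case clause. Below is a self-contained sketch of the same extractor-object pattern, with renamed types so it does not collide with the real Catalyst classes:

```scala
// Minimal stand-in for the Catalyst JoinType hierarchy, for illustration only.
sealed trait DemoJoinType
case object DemoInner extends DemoJoinType
case object DemoLeftSemi extends DemoJoinType
case object DemoLeftAnti extends DemoJoinType

// Mirrors LeftExistence: matches semi or anti and returns the concrete join type,
// so one pattern can cover both while still knowing which one it matched.
object DemoLeftExistence {
  def unapply(jt: DemoJoinType): Option[DemoJoinType] = jt match {
    case DemoLeftSemi | DemoLeftAnti => Some(jt)
    case _ => None
  }
}

def keepsOnlyLeftOutput(jt: DemoJoinType): Boolean = jt match {
  case DemoLeftExistence(_) => true // semi and anti joins emit only the left side's columns
  case _ => false
}

assert(keepsOnlyLeftOutput(DemoLeftAnti))
assert(!keepsOnlyLeftOutput(DemoInner))
```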
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index a18efc90ab..d3353beb09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -252,7 +252,7 @@ case class Join(
override def output: Seq[Attribute] = {
joinType match {
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -276,7 +276,7 @@ case class Join(
.union(splitConjunctivePredicates(condition.get).toSet)
case Inner =>
left.constraints.union(right.constraints)
- case LeftSemi =>
+ case LeftExistence(_) =>
left.constraints
case LeftOuter =>
left.constraints
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 262537d9c7..411e2372f2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -334,7 +334,7 @@ class PlanParserSuite extends PlanTest {
table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star()))
}
val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin)
-
+ val testExistence = Seq(testUnconditionalJoin, testConditionalJoin, testUsingJoin)
def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
tests.foreach(_(sql, jt))
}
@@ -348,6 +348,9 @@ class PlanParserSuite extends PlanTest {
test("right outer join", RightOuter, testAll)
test("full join", FullOuter, testAll)
test("full outer join", FullOuter, testAll)
+ test("left semi join", LeftSemi, testExistence)
+ test("left anti join", LeftAnti, testExistence)
+ test("anti join", LeftAnti, testExistence)
// Test multiple consecutive joins
assertEqual(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index ac8072f3ca..8d05ae470d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -38,7 +38,7 @@ class SparkPlanner(
DDLStrategy ::
SpecialLimits ::
Aggregation ::
- LeftSemiJoin ::
+ ExistenceJoin ::
EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d77aba7260..eee2b946e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -62,16 +62,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object LeftSemiJoin extends Strategy with PredicateHelper {
+ object ExistenceJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(
- LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
+ LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
Seq(joins.BroadcastHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
- case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+ case ExtractEquiJoinKeys(
+ LeftExistence(jt), leftKeys, rightKeys, condition, left, right) =>
Seq(joins.ShuffledHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
case _ => Nil
}
}
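The renamed `ExistenceJoin` strategy now plans both semi and anti equi-joins: a broadcastable right side becomes a `BroadcastHashJoin`, anything else a `ShuffledHashJoin` with `BuildRight`. A rough way to observe this in the spark-shell, assuming the `testData`/`testData2` temporary tables from `JoinSuite` are registered (the exact operator chosen depends on the auto-broadcast threshold):

```scala
// Inspect the physical plan picked for an anti join; with a small, broadcastable right
// side this should show BroadcastHashJoin with join type LeftAnti, otherwise ShuffledHashJoin.
val df = sqlContext.sql("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a")
df.explain()
```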
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 67ac9e94ff..e3d554c2de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.collection.CompactBuffer
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -87,6 +86,7 @@ case class BroadcastHashJoin(
case Inner => codegenInner(ctx, input)
case LeftOuter | RightOuter => codegenOuter(ctx, input)
case LeftSemi => codegenSemi(ctx, input)
+ case LeftAnti => codegenAnti(ctx, input)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
@@ -160,15 +160,14 @@ case class BroadcastHashJoin(
}
/**
- * Generates the code for Inner join.
+ * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
+ * and Left Anti joins.
*/
- private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ private def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
val checkCondition = if (condition.isDefined) {
val expr = condition.get
// evaluate the variables from build side that used by condition
@@ -184,6 +183,17 @@ case class BroadcastHashJoin(
} else {
""
}
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
val resultVars = buildSide match {
case BuildLeft => buildVars ++ input
@@ -221,7 +231,6 @@ case class BroadcastHashJoin(
}
}
-
/**
* Generates the code for left or right outer join.
*/
@@ -276,7 +285,6 @@ case class BroadcastHashJoin(
ctx.copyResult = true
val matches = ctx.freshName("matches")
val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val i = ctx.freshName("i")
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
@@ -304,26 +312,8 @@ case class BroadcastHashJoin(
private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
-
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
- // filter the output via condition
- ctx.currentVars = input ++ buildVars
- val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
- s"""
- |$eval
- |${ev.code}
- |if (${ev.isNull} || !${ev.value}) continue;
- """.stripMargin
- } else {
- ""
- }
-
if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
@@ -357,4 +347,57 @@ case class BroadcastHashJoin(
""".stripMargin
}
}
+
+ /**
+ * Generates the code for anti join.
+ */
+ private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ | if ($matched != null) {
+ | // Evaluate the condition.
+ | $checkCondition
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
+ | if ($matches != null) {
+ | // Evaluate the condition.
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | $found = true;
+ | }
+ | if ($found) continue;
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ }
+ }
}
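For readability, here is a hand-written Scala sketch of the probe logic that the generated anti-join code above implements for the non-unique-key path; `lookup`, `keyHasNull` and `predicate` are illustrative stand-ins for the `HashedRelation` lookup, the null check on the join key and the bound join condition.

```scala
// A stream row survives the anti join unless its key is non-null and at least one
// build row with the same key also passes the (optional) join condition.
def antiProbe[K, L, R](
    streamed: Iterator[(K, L)],
    lookup: K => Option[Seq[R]],
    keyHasNull: K => Boolean,
    predicate: (L, R) => Boolean): Iterator[L] = {
  streamed.collect {
    case (key, row) if keyHasNull(key) ||
        !lookup(key).exists(_.exists(build => predicate(row, build))) =>
      row
  }
}
```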
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 4143e944e5..4ba710c10a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case x =>
throw new IllegalArgumentException(
@@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin(
* The implementation for these joins:
*
* LeftSemi with BuildRight
+ * Anti with BuildRight
*/
- private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ private def leftExistenceJoin(
+ relation: Broadcast[Array[InternalRow]],
+ exists: Boolean): RDD[InternalRow] = {
assert(buildSide == BuildRight)
streamed.execute().mapPartitionsInternal { streamedIter =>
val buildRows = relation.value
@@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin(
if (condition.isDefined) {
streamedIter.filter(l =>
- buildRows.exists(r => boundCondition(joinedRow(l, r)))
+ buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
)
+ } else if (buildRows.nonEmpty == exists) {
+ streamedIter
} else {
- streamedIter.filter(r => !buildRows.isEmpty)
+ Iterator.empty
}
}
}
@@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin(
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
/** All rows that either match both-way, or rows from streamed joined with nulls. */
@@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin(
}
i += 1
}
- return sparkContext.makeRDD(buf.toSeq)
+ return sparkContext.makeRDD(buf)
+ }
+
+ val notMatchedBroadcastRows: Seq[InternalRow] = {
+ val nulls = new GenericMutableRow(streamed.output.size)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ joinedRow.withLeft(nulls)
+ while (i < buildRows.length) {
+ if (!matchedBroadcastRows.get(i)) {
+ buf += joinedRow.withRight(buildRows(i)).copy()
+ }
+ i += 1
+ }
+ buf
+ }
+
+ if (joinType == LeftAnti) {
+ return sparkContext.makeRDD(notMatchedBroadcastRows)
}
val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
@@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin(
}
}
- val notMatchedBroadcastRows: Seq[InternalRow] = {
- val nulls = new GenericMutableRow(streamed.output.size)
- val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val buildRows = relation.value
- val joinedRow = new JoinedRow
- joinedRow.withLeft(nulls)
- while (i < buildRows.length) {
- if (!matchedBroadcastRows.get(i)) {
- buf += joinedRow.withRight(buildRows(i)).copy()
- }
- i += 1
- }
- buf.toSeq
- }
-
sparkContext.union(
matchedStreamRows,
sparkContext.makeRDD(notMatchedBroadcastRows)
@@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin(
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
outerJoin(broadcastedRelation)
case (LeftSemi, BuildRight) =>
- leftSemiJoin(broadcastedRelation)
+ leftExistenceJoin(broadcastedRelation, exists = true)
+ case (LeftAnti, BuildRight) =>
+ leftExistenceJoin(broadcastedRelation, exists = false)
case _ =>
/**
* LeftOuter with BuildLeft
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
defaultJoin(broadcastedRelation)
}
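The new `leftExistenceJoin` above folds semi and anti joins into one routine by comparing an `exists` flag against whether any broadcast row matches the streamed row. A self-contained sketch of that filter on plain Scala collections (names are illustrative):

```scala
// exists = true keeps rows that have a match (semi join);
// exists = false keeps rows that have none (anti join).
def existenceFilter[L, R](
    streamed: Seq[L],
    broadcast: Seq[R],
    condition: Option[(L, R) => Boolean],
    exists: Boolean): Seq[L] = condition match {
  case Some(pred) =>
    streamed.filter(l => broadcast.exists(r => pred(l, r)) == exists)
  case None if broadcast.nonEmpty == exists =>
    // Without a condition every streamed row matches iff the broadcast side is non-empty,
    // so either all rows are kept or none are.
    streamed
  case None =>
    Seq.empty
}
```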
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index b7c0f3e7d1..8f45d57126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -47,7 +47,7 @@ trait HashJoin {
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case x =>
throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
@@ -197,6 +197,20 @@ trait HashJoin {
}
}
+ private def antiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator()
+ val joinedRow = new JoinedRow
+ streamIter.filter { current =>
+ val key = joinKeys(current)
+ lazy val buildIter = hashedRelation.get(key)
+ key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists {
+ row => boundCondition(joinedRow(current, row))
+ })
+ }
+ }
+
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
@@ -209,6 +223,8 @@ trait HashJoin {
outerJoin(streamedIter, hashed)
case LeftSemi =>
semiJoin(streamedIter, hashed)
+ case LeftAnti =>
+ antiJoin(streamedIter, hashed)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index c63faacf33..bf86096379 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -45,6 +45,7 @@ case class ShuffledHashJoin(
override def outputPartitioning: Partitioning = joinType match {
case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case LeftAnti => left.outputPartitioning
case LeftSemi => left.outputPartitioning
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index a5a4ff13de..a87a41c126 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -41,7 +41,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
assert(planned.size === 1)
}
- def assertJoin(sqlString: String, c: Class[_]): Any = {
+ def assertJoin(pair: (String, Class[_])): Any = {
+ val (sqlString, c) = pair
val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
@@ -53,8 +54,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
assert(operators.size === 1)
- if (operators(0).getClass() != c) {
- fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
+ if (operators.head.getClass != c) {
+ fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical")
}
}
@@ -93,8 +94,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
- classOf[BroadcastNestedLoopJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoin])
+ ).foreach(assertJoin)
}
}
@@ -114,7 +117,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a where key = 2",
classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
sql("UNCACHE TABLE testData")
}
@@ -129,7 +132,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
sql("UNCACHE TABLE testData")
}
@@ -419,25 +422,22 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(null, 10))
}
- test("broadcasted left semi join operator selection") {
+ test("broadcasted existence join operator selection") {
sqlContext.cacheManager.clearCache()
sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[BroadcastHashJoin])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
- }
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData ANT JOIN testData2 ON key = a", classOf[BroadcastHashJoin])
+ ).foreach(assertJoin)
}
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
- }
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
+ ).foreach(assertJoin)
}
sql("UNCACHE TABLE testData")
@@ -489,7 +489,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)",
classOf[BroadcastNestedLoopJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
checkAnswer(
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 985a96f684..8cdfa8afd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
-class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
+class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
private lazy val left = sqlContext.createDataFrame(
sparkContext.parallelize(Seq(
@@ -58,14 +58,20 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
LessThan(left.col("b").expr, right.col("d").expr))
}
+ private lazy val conditionNEQ = {
+ And((left.col("a") < right.col("c")).expr,
+ LessThan(left.col("b").expr, right.col("d").expr))
+ }
+
// Note: the input dataframes and expression must be evaluated lazily because
// the SQLContext should be used only within a test to keep SQL tests stable
- private def testLeftSemiJoin(
+ private def testExistenceJoin(
testName: String,
+ joinType: JoinType,
leftRows: => DataFrame,
rightRows: => DataFrame,
condition: => Expression,
- expectedAnswer: Seq[Product]): Unit = {
+ expectedAnswer: Seq[Row]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
@@ -73,25 +79,26 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
}
test(s"$testName using ShuffledHashJoin") {
- extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
ShuffledHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)),
- expectedAnswer.map(Row.fromTuple),
+ leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
+ expectedAnswer,
sortAnswers = true)
}
}
}
test(s"$testName using BroadcastHashJoin") {
- extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastHashJoin(
+ leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
+ expectedAnswer,
sortAnswers = true)
}
}
@@ -100,8 +107,9 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using BroadcastNestedLoopJoin build left") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition))),
+ expectedAnswer,
sortAnswers = true)
}
}
@@ -109,21 +117,43 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using BroadcastNestedLoopJoin build right") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition))),
+ expectedAnswer,
sortAnswers = true)
}
}
}
- testLeftSemiJoin(
- "basic test",
+ testExistenceJoin(
+ "basic test for left semi join",
+ LeftSemi,
+ left,
+ right,
+ condition,
+ Seq(Row(2, 1.0), Row(2, 1.0)))
+
+ testExistenceJoin(
+ "basic test for left semi non equal join",
+ LeftSemi,
+ left,
+ right,
+ conditionNEQ,
+ Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0)))
+
+ testExistenceJoin(
+ "basic test for anti join",
+ LeftAnti,
left,
right,
condition,
- Seq(
- (2, 1.0),
- (2, 1.0)
- )
- )
+ Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
+
+ testExistenceJoin(
+ "basic test for anti non equal join",
+ LeftAnti,
+ left,
+ right,
+ conditionNEQ,
+ Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index cff24e28fd..b992fda18c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -92,7 +92,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
DataSinks,
Scripts,
Aggregation,
- LeftSemiJoin,
+ ExistenceJoin,
EquiJoinSelection,
BasicOperators,
BroadcastNestedLoop,