author     Herman van Hovell <hvanhovell@questtec.nl>    2016-04-06 19:25:10 -0700
committer  Reynold Xin <rxin@databricks.com>             2016-04-06 19:25:10 -0700
commit     d76592276f9f66fed8012d876595de8717f516a9 (patch)
tree       bb3570eac8b6885efe77677d18cda30df7cb0a69 /sql/core
parent     4901086fea969a34ec312ef4a8f83d84e1bf21fb (diff)
[SPARK-12610][SQL] Left Anti Join
### What changes were proposed in this pull request?

This PR adds support for `LEFT ANTI JOIN` to Spark SQL. A `LEFT ANTI JOIN` is the exact opposite of a `LEFT SEMI JOIN` and can be used to identify rows in one dataset that are not in another dataset. Note that `null`s on the left side of the join cannot match a row on the right hand side of the join; as a result, a left anti join will always select a row with a `null` in one or more of its keys.

We currently add support for the following SQL join syntax:

```sql
SELECT * FROM tbl1 A LEFT ANTI JOIN tbl2 B ON A.Id = B.Id
```

Or using a dataframe:

```scala
tbl1.as("a").join(tbl2.as("b"), $"a.id" === $"b.id", "left_anti")
```

This PR serves as the basis for implementing `NOT EXISTS` and `NOT IN (...)` correlated sub-queries. It would also serve as a good basis for implementing a more efficient `EXCEPT` operator.

The PR has been (loosely) based on PRs by both davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/10563); credit should be given where credit is due.

This PR adds support for `LEFT ANTI JOIN` to `BroadcastHashJoin` (including code generation), `ShuffledHashJoin` and `BroadcastNestedLoopJoin`.

### How was this patch tested?

Added tests to `JoinSuite` and ported `ExistenceJoinSuite` from https://github.com/apache/spark/pull/10563.

cc davies chenghao-intel rxin

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12214 from hvanhovell/SPARK-12610.
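For illustration, here is a minimal end-to-end sketch of the new join type (hypothetical tables, not part of this patch); note how the null-keyed left row survives the anti join because a `null` key can never match:

```scala
import sqlContext.implicits._

// Hypothetical data: `id` is nullable on the left side.
val tbl1 = Seq((Some(1), "a"), (Some(2), "b"), (None, "c")).toDF("id", "v")
val tbl2 = Seq((Some(1), "x")).toDF("id", "w")

// Keep only the left rows that have no match on the right.
tbl1.as("a").join(tbl2.as("b"), $"a.id" === $"b.id", "left_anti").show()
// +----+---+
// |  id|  v|
// +----+---+
// |   2|  b|
// |null|  c|
// +----+---+
```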
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala                 |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala              | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala      | 99
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 57
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala               | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala       |  1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                              | 36
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala (renamed from sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala) | 74
8 files changed, 201 insertions(+), 97 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index ac8072f3ca..8d05ae470d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -38,7 +38,7 @@ class SparkPlanner(
DDLStrategy ::
SpecialLimits ::
Aggregation ::
- LeftSemiJoin ::
+ ExistenceJoin ::
EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d77aba7260..eee2b946e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -62,16 +62,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object LeftSemiJoin extends Strategy with PredicateHelper {
+ object ExistenceJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(
- LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
+ LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
Seq(joins.BroadcastHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
- case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+ case ExtractEquiJoinKeys(
+ LeftExistence(jt), leftKeys, rightKeys, condition, left, right) =>
Seq(joins.ShuffledHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
case _ => Nil
}
}
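The `LeftExistence(jt)` pattern above comes from an extractor defined on the catalyst side (outside this sql/core diff); roughly, it matches the two join types that return only left-side rows:

```scala
// Sketch of the catalyst-side extractor (not part of this diff).
object LeftExistence {
  def unapply(joinType: JoinType): Option[JoinType] = joinType match {
    case LeftSemi | LeftAnti => Some(joinType)
    case _ => None
  }
}
```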
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 67ac9e94ff..e3d554c2de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.collection.CompactBuffer
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -87,6 +86,7 @@ case class BroadcastHashJoin(
case Inner => codegenInner(ctx, input)
case LeftOuter | RightOuter => codegenOuter(ctx, input)
case LeftSemi => codegenSemi(ctx, input)
+ case LeftAnti => codegenAnti(ctx, input)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
@@ -160,15 +160,14 @@ case class BroadcastHashJoin(
}
/**
- * Generates the code for Inner join.
+ * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
+ * and Left Anti joins.
*/
- private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ private def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
val checkCondition = if (condition.isDefined) {
val expr = condition.get
// evaluate the variables from build side that used by condition
@@ -184,6 +183,17 @@ case class BroadcastHashJoin(
} else {
""
}
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
val resultVars = buildSide match {
case BuildLeft => buildVars ++ input
@@ -221,7 +231,6 @@ case class BroadcastHashJoin(
}
}
-
/**
* Generates the code for left or right outer join.
*/
@@ -276,7 +285,6 @@ case class BroadcastHashJoin(
ctx.copyResult = true
val matches = ctx.freshName("matches")
val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val i = ctx.freshName("i")
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
@@ -304,26 +312,8 @@ case class BroadcastHashJoin(
private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
-
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
- // filter the output via condition
- ctx.currentVars = input ++ buildVars
- val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
- s"""
- |$eval
- |${ev.code}
- |if (${ev.isNull} || !${ev.value}) continue;
- """.stripMargin
- } else {
- ""
- }
-
if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
@@ -357,4 +347,57 @@ case class BroadcastHashJoin(
""".stripMargin
}
}
+
+ /**
+ * Generates the code for anti join.
+ */
+ private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ | if ($matched != null) {
+ | // Evaluate the condition.
+ | $checkCondition
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
+ | if ($matches != null) {
+ | // Evaluate the condition.
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | $found = true;
+ | }
+ | if ($found) continue;
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ }
+ }
}
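A subtlety in the non-unique branch of `codegenAnti`: `$checkCondition` ends in `continue`, which inside the inner `while` merely advances to the next candidate match, so `$found` only becomes true once a match passes the condition; the outer `if ($found) continue;` then drops the streamed row. The per-row decision the generated loop encodes can be sketched in plain Scala (an assumed helper, not the generated code):

```scala
import org.apache.spark.sql.catalyst.InternalRow

// A streamed row survives the anti join iff no build-side match passes the
// (possibly trivial) join condition.
def emitForAnti(matches: Iterator[InternalRow],
                passesCondition: InternalRow => Boolean): Boolean =
  !matches.exists(passesCondition)
```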
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 4143e944e5..4ba710c10a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case x =>
throw new IllegalArgumentException(
@@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin(
* The implementation for these joins:
*
* LeftSemi with BuildRight
+ * Anti with BuildRight
*/
- private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ private def leftExistenceJoin(
+ relation: Broadcast[Array[InternalRow]],
+ exists: Boolean): RDD[InternalRow] = {
assert(buildSide == BuildRight)
streamed.execute().mapPartitionsInternal { streamedIter =>
val buildRows = relation.value
@@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin(
if (condition.isDefined) {
streamedIter.filter(l =>
- buildRows.exists(r => boundCondition(joinedRow(l, r)))
+ buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
)
+ } else if (buildRows.nonEmpty == exists) {
+ streamedIter
} else {
- streamedIter.filter(r => !buildRows.isEmpty)
+ Iterator.empty
}
}
}
@@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin(
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
/** All rows that either match both-way, or rows from streamed joined with nulls. */
@@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin(
}
i += 1
}
- return sparkContext.makeRDD(buf.toSeq)
+ return sparkContext.makeRDD(buf)
+ }
+
+ val notMatchedBroadcastRows: Seq[InternalRow] = {
+ val nulls = new GenericMutableRow(streamed.output.size)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ joinedRow.withLeft(nulls)
+ while (i < buildRows.length) {
+ if (!matchedBroadcastRows.get(i)) {
+ buf += joinedRow.withRight(buildRows(i)).copy()
+ }
+ i += 1
+ }
+ buf
+ }
+
+ if (joinType == LeftAnti) {
+ return sparkContext.makeRDD(notMatchedBroadcastRows)
}
val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
@@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin(
}
}
- val notMatchedBroadcastRows: Seq[InternalRow] = {
- val nulls = new GenericMutableRow(streamed.output.size)
- val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val buildRows = relation.value
- val joinedRow = new JoinedRow
- joinedRow.withLeft(nulls)
- while (i < buildRows.length) {
- if (!matchedBroadcastRows.get(i)) {
- buf += joinedRow.withRight(buildRows(i)).copy()
- }
- i += 1
- }
- buf.toSeq
- }
-
sparkContext.union(
matchedStreamRows,
sparkContext.makeRDD(notMatchedBroadcastRows)
@@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin(
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
outerJoin(broadcastedRelation)
case (LeftSemi, BuildRight) =>
- leftSemiJoin(broadcastedRelation)
+ leftExistenceJoin(broadcastedRelation, exists = true)
+ case (LeftAnti, BuildRight) =>
+ leftExistenceJoin(broadcastedRelation, exists = false)
case _ =>
/**
* LeftOuter with BuildLeft
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
defaultJoin(broadcastedRelation)
}
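The single `exists` flag lets one code path serve both semi (`exists = true`) and anti (`exists = false`) joins. On plain collections, the equivalent filter looks like this (a sketch with assumed names, mirroring `leftExistenceJoin`):

```scala
// Semi join keeps streamed rows whose predicate matches some build row;
// anti join keeps those matching none.
def leftExistenceFilter[L, R](
    streamed: Iterator[L],
    buildRows: Array[R],
    condition: Option[(L, R) => Boolean],
    exists: Boolean): Iterator[L] = condition match {
  case Some(c) => streamed.filter(l => buildRows.exists(r => c(l, r)) == exists)
  // Without a condition, every streamed row matches iff the build side is non-empty.
  case None => if (buildRows.nonEmpty == exists) streamed else Iterator.empty
}
```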
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index b7c0f3e7d1..8f45d57126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -47,7 +47,7 @@ trait HashJoin {
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case x =>
throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
@@ -197,6 +197,20 @@ trait HashJoin {
}
}
+ private def antiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator()
+ val joinedRow = new JoinedRow
+ streamIter.filter { current =>
+ val key = joinKeys(current)
+ lazy val buildIter = hashedRelation.get(key)
+ key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists {
+ row => boundCondition(joinedRow(current, row))
+ })
+ }
+ }
+
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
@@ -209,6 +223,8 @@ trait HashJoin {
outerJoin(streamedIter, hashed)
case LeftSemi =>
semiJoin(streamedIter, hashed)
+ case LeftAnti =>
+ antiJoin(streamedIter, hashed)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index c63faacf33..bf86096379 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -45,6 +45,7 @@ case class ShuffledHashJoin(
override def outputPartitioning: Partitioning = joinType match {
case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case LeftAnti => left.outputPartitioning
case LeftSemi => left.outputPartitioning
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index a5a4ff13de..a87a41c126 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -41,7 +41,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
assert(planned.size === 1)
}
- def assertJoin(sqlString: String, c: Class[_]): Any = {
+ def assertJoin(pair: (String, Class[_])): Any = {
+ val (sqlString, c) = pair
val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
@@ -53,8 +54,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
assert(operators.size === 1)
- if (operators(0).getClass() != c) {
- fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
+ if (operators.head.getClass != c) {
+ fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical")
}
}
@@ -93,8 +94,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
- classOf[BroadcastNestedLoopJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoin])
+ ).foreach(assertJoin)
}
}
@@ -114,7 +117,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a where key = 2",
classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
sql("UNCACHE TABLE testData")
}
@@ -129,7 +132,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
sql("UNCACHE TABLE testData")
}
@@ -419,25 +422,22 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(null, 10))
}
- test("broadcasted left semi join operator selection") {
+ test("broadcasted existence join operator selection") {
sqlContext.cacheManager.clearCache()
sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[BroadcastHashJoin])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
- }
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData ANT JOIN testData2 ON key = a", classOf[BroadcastHashJoin])
+ ).foreach(assertJoin)
}
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
- }
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
+ ).foreach(assertJoin)
}
sql("UNCACHE TABLE testData")
@@ -489,7 +489,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)",
classOf[BroadcastNestedLoopJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
checkAnswer(
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 985a96f684..8cdfa8afd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
-class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
+class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
private lazy val left = sqlContext.createDataFrame(
sparkContext.parallelize(Seq(
@@ -58,14 +58,20 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
LessThan(left.col("b").expr, right.col("d").expr))
}
+ private lazy val conditionNEQ = {
+ And((left.col("a") < right.col("c")).expr,
+ LessThan(left.col("b").expr, right.col("d").expr))
+ }
+
// Note: the input dataframes and expression must be evaluated lazily because
// the SQLContext should be used only within a test to keep SQL tests stable
- private def testLeftSemiJoin(
+ private def testExistenceJoin(
testName: String,
+ joinType: JoinType,
leftRows: => DataFrame,
rightRows: => DataFrame,
condition: => Expression,
- expectedAnswer: Seq[Product]): Unit = {
+ expectedAnswer: Seq[Row]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
@@ -73,25 +79,26 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
}
test(s"$testName using ShuffledHashJoin") {
- extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
ShuffledHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)),
- expectedAnswer.map(Row.fromTuple),
+ leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
+ expectedAnswer,
sortAnswers = true)
}
}
}
test(s"$testName using BroadcastHashJoin") {
- extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastHashJoin(
+ leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
+ expectedAnswer,
sortAnswers = true)
}
}
@@ -100,8 +107,9 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using BroadcastNestedLoopJoin build left") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition))),
+ expectedAnswer,
sortAnswers = true)
}
}
@@ -109,21 +117,43 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using BroadcastNestedLoopJoin build right") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition))),
+ expectedAnswer,
sortAnswers = true)
}
}
}
- testLeftSemiJoin(
- "basic test",
+ testExistenceJoin(
+ "basic test for left semi join",
+ LeftSemi,
+ left,
+ right,
+ condition,
+ Seq(Row(2, 1.0), Row(2, 1.0)))
+
+ testExistenceJoin(
+ "basic test for left semi non equal join",
+ LeftSemi,
+ left,
+ right,
+ conditionNEQ,
+ Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0)))
+
+ testExistenceJoin(
+ "basic test for anti join",
+ LeftAnti,
left,
right,
condition,
- Seq(
- (2, 1.0),
- (2, 1.0)
- )
- )
+ Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
+
+ testExistenceJoin(
+ "basic test for anti non equal join",
+ LeftAnti,
+ left,
+ right,
+ conditionNEQ,
+ Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
}