aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-05-02 12:58:59 -0700
committerDavies Liu <davies.liu@gmail.com>2016-05-02 12:58:59 -0700
commit95e372141a102f933045fe9472bbe1ce8c91b5d5 (patch)
tree6d3388ab7342d548dc6a5106757b511ba849f448 /sql
parent6e6320122ea84247c67e2d0fb0e6af54e2c5bb31 (diff)
downloadspark-95e372141a102f933045fe9472bbe1ce8c91b5d5.tar.gz
spark-95e372141a102f933045fe9472bbe1ce8c91b5d5.tar.bz2
spark-95e372141a102f933045fe9472bbe1ce8c91b5d5.zip
[SPARK-14781] [SQL] support nested predicate subquery
## What changes were proposed in this pull request? In order to support nested predicate subquery, this PR introduce an internal join type ExistenceJoin, which will emit all the rows from left, plus an additional column, which presents there are any rows matched from right or not (it's not null-aware right now). This additional column could be used to replace the subquery in Filter. In theory, all the predicate subquery could use this join type, but it's slower than LeftSemi and LeftAnti, so it's only used for nested subquery (subquery inside OR). For example, the following SQL: ```sql SELECT a FROM t WHERE EXISTS (select 0) OR EXISTS (select 1) ``` This PR also fix a bug in predicate subquery push down through join (they should not). Nested null-aware subquery is still not supported. For example, `a > 3 OR b NOT IN (select bb from t)` After this, we could run TPCDS query Q10, Q35, Q45 ## How was this patch tested? Added unit tests. Author: Davies Liu <davies@databricks.com> Closes #12820 from davies/or_exists.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala15
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala41
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala66
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala94
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala40
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala25
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala50
14 files changed, 345 insertions, 61 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 61a7d9ea24..6e3a14dfb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -115,8 +115,9 @@ trait CheckAnalysis extends PredicateHelper {
case f @ Filter(condition, child) =>
splitConjunctivePredicates(condition).foreach {
case _: PredicateSubquery | Not(_: PredicateSubquery) =>
- case e if PredicateSubquery.hasPredicateSubquery(e) =>
- failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e")
+ case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
+ failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" +
+ s" conditions: $e")
case e =>
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index cd6d3a00b7..eed062f8bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -92,7 +92,7 @@ case class PredicateSubquery(
extends SubqueryExpression with Predicate with Unevaluable {
override lazy val resolved = childrenResolved && query.resolved
override lazy val references: AttributeSet = super.references -- query.outputSet
- override def nullable: Boolean = false
+ override def nullable: Boolean = nullAware
override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
@@ -105,6 +105,19 @@ object PredicateSubquery {
case _ => false
}.isDefined
}
+
+ /**
+ * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could
+ * turn the null-aware predicate into not-null-aware predicate.
+ */
+ def hasNullAwarePredicateWithinNot(e: Expression): Boolean = {
+ e.find{ x =>
+ x.isInstanceOf[Not] && e.find {
+ case p: PredicateSubquery => p.nullAware
+ case _ => false
+ }.isDefined
+ }.isDefined
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a147fff274..e1c969f50f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -100,8 +100,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
- EliminateSerialization,
- RewritePredicateSubquery) ::
+ EliminateSerialization) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
@@ -109,7 +108,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("OptimizeCodegen", Once,
- OptimizeCodegen(conf)) :: Nil
+ OptimizeCodegen(conf)) ::
+ Batch("RewriteSubquery", Once,
+ RewritePredicateSubquery,
+ CollapseProject) :: Nil
}
/**
@@ -1078,7 +1080,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
- Join(input(0), input(1), Inner, conditions.reduceLeftOption(And))
+ val (joinConditions, others) = conditions.partition(
+ e => !PredicateSubquery.hasPredicateSubquery(e))
+ val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And))
+ if (others.nonEmpty) {
+ Filter(others.reduceLeft(And), join)
+ } else {
+ join
+ }
} else {
val left :: rest = input.toList
// find out the first join that have at least one join condition
@@ -1091,7 +1100,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
val right = conditionalJoin.getOrElse(rest.head)
val joinedRefs = left.outputSet ++ right.outputSet
- val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs))
+ val (joinConditions, others) = conditions.partition(
+ e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e))
val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))
// should not have reference to same logical plan
@@ -1201,9 +1211,16 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
- val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
+ val (newJoinConditions, others) =
+ commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e))
+ val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
- Join(newLeft, newRight, Inner, newJoinCond)
+ val join = Join(newLeft, newRight, Inner, newJoinCond)
+ if (others.nonEmpty) {
+ Filter(others.reduceLeft(And), join)
+ } else {
+ join
+ }
case RightOuter =>
// push down the right side only `where` condition
val newLeft = left
@@ -1543,6 +1560,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS
// if performance matters to you.
Join(p, sub, LeftAnti, Option(Or(anyNull, condition)))
+ case (p, predicate) =>
+ var joined = p
+ val replaced = predicate transformUp {
+ case PredicateSubquery(sub, conditions, nullAware, _) =>
+ // TODO: support null-aware join
+ val exists = AttributeReference("exists", BooleanType, false)()
+ joined = Join(joined, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
+ exists
+ }
+ Project(p.output, Filter(replaced, joined))
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 13f57c54a5..80674d9b4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.Attribute
object JoinType {
def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
@@ -69,6 +70,14 @@ case object LeftAnti extends JoinType {
override def sql: String = "LEFT ANTI"
}
+case class ExistenceJoin(exists: Attribute) extends JoinType {
+ override def sql: String = {
+ // This join type is only used in the end of optimizer and physical plans, we will not
+ // generate SQL for this join type
+ throw new UnsupportedOperationException
+ }
+}
+
case class NaturalJoin(tpe: JoinType) extends JoinType {
require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
"Unsupported natural join type " + tpe)
@@ -84,6 +93,7 @@ case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) exte
object LeftExistence {
def unapply(joinType: JoinType): Option[JoinType] = joinType match {
case LeftSemi | LeftAnti => Some(joinType)
+ case j: ExistenceJoin => Some(joinType)
case _ => None
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index b2297bbcaa..830a7ac77d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -273,6 +273,8 @@ case class Join(
override def output: Seq[Attribute] = {
joinType match {
+ case j: ExistenceJoin =>
+ left.output :+ j.exists
case LeftExistence(_) =>
left.output
case LeftOuter =>
@@ -295,6 +297,8 @@ case class Join(
case LeftSemi if condition.isDefined =>
left.constraints
.union(splitConjunctivePredicates(condition.get).toSet)
+ case j: ExistenceJoin =>
+ left.constraints
case Inner =>
left.constraints.union(right.constraints)
case LeftExistence(_) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 1b08913ddd..10bff3d6d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -459,11 +459,14 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", BooleanType)()
- val plan1 = Filter(Cast(In(a, Seq(ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a))
- assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
+ val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType),
+ LocalRelation(a))
+ assertAnalysisError(plan1,
+ "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)
- val plan2 = Filter(Or(In(a, Seq(ListQuery(LocalRelation(b)))), c), LocalRelation(a, c))
- assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
+ val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c))
+ assertAnalysisError(plan2,
+ "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)
}
test("PredicateSubQuery correlated predicate is nested in an illegal plan") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 238334e26b..9747e58f43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -118,6 +118,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
private def canBuildRight(joinType: JoinType): Boolean = joinType match {
case Inner | LeftOuter | LeftSemi | LeftAnti => true
+ case j: ExistenceJoin => true
case _ => false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 587c603192..7c194ab726 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -48,8 +48,6 @@ case class BroadcastHashJoinExec(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
- override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
buildSide match {
@@ -85,6 +83,7 @@ case class BroadcastHashJoinExec(
case LeftOuter | RightOuter => codegenOuter(ctx, input)
case LeftSemi => codegenSemi(ctx, input)
case LeftAnti => codegenAnti(ctx, input)
+ case j: ExistenceJoin => codegenExistence(ctx, input)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
@@ -407,4 +406,67 @@ case class BroadcastHashJoinExec(
""".stripMargin
}
}
+
+ /**
+ * Generates the code for existence join.
+ */
+ private def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ val existsVar = ctx.freshName("exists")
+
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
+ s"""
+ |$eval
+ |${ev.code}
+ |$existsVar = !${ev.isNull} && ${ev.value};
+ """.stripMargin
+ } else {
+ s"$existsVar = true;"
+ }
+
+ val resultVar = input ++ Seq(ExprCode("", "false", existsVar))
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |boolean $existsVar = false;
+ |if ($matched != null) {
+ | $checkCondition
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVar)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $existsVar = false;
+ |if ($matches != null) {
+ | while (!$existsVar && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVar)}
+ """.stripMargin
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index a659bf26e3..2a250ecce6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -50,19 +50,16 @@ case class BroadcastNestedLoopJoinExec(
UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
}
- private[this] def genResultProjection: InternalRow => InternalRow = {
- if (joinType == LeftSemi) {
+ private[this] def genResultProjection: InternalRow => InternalRow = joinType match {
+ case LeftExistence(j) =>
UnsafeProjection.create(output, output)
- } else {
+ case other =>
// Always put the stream side on left to simplify implementation
// both of left and right side could be null
UnsafeProjection.create(
output, (streamed.output ++ broadcast.output).map(_.withNullability(true)))
- }
}
- override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
override def output: Seq[Attribute] = {
joinType match {
case Inner =>
@@ -73,6 +70,8 @@ case class BroadcastNestedLoopJoinExec(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case j: ExistenceJoin =>
+ left.output :+ j.exists
case LeftExistence(_) =>
left.output
case x =>
@@ -197,6 +196,28 @@ case class BroadcastNestedLoopJoinExec(
}
}
+ private def existenceJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ assert(buildSide == BuildRight)
+ streamed.execute().mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+
+ if (condition.isDefined) {
+ val resultRow = new GenericMutableRow(Array[Any](null))
+ streamedIter.map { row =>
+ val result = buildRows.exists(r => boundCondition(joinedRow(row, r)))
+ resultRow.setBoolean(0, result)
+ joinedRow(row, resultRow)
+ }
+ } else {
+ val resultRow = new GenericMutableRow(Array[Any](buildRows.nonEmpty))
+ streamedIter.map { row =>
+ joinedRow(row, resultRow)
+ }
+ }
+ }
+ }
+
/**
* The implementation for these joins:
*
@@ -204,7 +225,8 @@ case class BroadcastNestedLoopJoinExec(
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
- * Anti with BuildLeft
+ * LeftAnti with BuildLeft
+ * ExistenceJoin with BuildLeft
*/
private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
/** All rows that either match both-way, or rows from streamed joined with nulls. */
@@ -231,27 +253,50 @@ case class BroadcastNestedLoopJoinExec(
new BitSet(relation.value.length)
)(_ | _)
- if (joinType == LeftSemi) {
- assert(buildSide == BuildLeft)
- val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val rel = relation.value
- while (i < rel.length) {
- if (matchedBroadcastRows.get(i)) {
- buf += rel(i).copy()
+ joinType match {
+ case LeftSemi =>
+ assert(buildSide == BuildLeft)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val rel = relation.value
+ while (i < rel.length) {
+ if (matchedBroadcastRows.get(i)) {
+ buf += rel(i).copy()
+ }
+ i += 1
}
- i += 1
- }
- return sparkContext.makeRDD(buf)
+ return sparkContext.makeRDD(buf)
+ case j: ExistenceJoin =>
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val rel = relation.value
+ while (i < rel.length) {
+ val result = new GenericInternalRow(Array[Any](matchedBroadcastRows.get(i)))
+ buf += new JoinedRow(rel(i).copy(), result)
+ i += 1
+ }
+ return sparkContext.makeRDD(buf)
+ case LeftAnti =>
+ val notMatched: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val rel = relation.value
+ while (i < rel.length) {
+ if (!matchedBroadcastRows.get(i)) {
+ notMatched += rel(i).copy()
+ }
+ i += 1
+ }
+ return sparkContext.makeRDD(notMatched)
+ case o =>
}
val notMatchedBroadcastRows: Seq[InternalRow] = {
val nulls = new GenericMutableRow(streamed.output.size)
val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val buildRows = relation.value
val joinedRow = new JoinedRow
joinedRow.withLeft(nulls)
+ var i = 0
+ val buildRows = relation.value
while (i < buildRows.length) {
if (!matchedBroadcastRows.get(i)) {
buf += joinedRow.withRight(buildRows(i)).copy()
@@ -261,10 +306,6 @@ case class BroadcastNestedLoopJoinExec(
buf
}
- if (joinType == LeftAnti) {
- return sparkContext.makeRDD(notMatchedBroadcastRows)
- }
-
val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
val buildRows = relation.value
val joinedRow = new JoinedRow
@@ -308,13 +349,16 @@ case class BroadcastNestedLoopJoinExec(
leftExistenceJoin(broadcastedRelation, exists = true)
case (LeftAnti, BuildRight) =>
leftExistenceJoin(broadcastedRelation, exists = false)
+ case (j: ExistenceJoin, BuildRight) =>
+ existenceJoin(broadcastedRelation)
case _ =>
/**
* LeftOuter with BuildLeft
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
- * Anti with BuildLeft
+ * LeftAnti with BuildLeft
+ * ExistenceJoin with BuildLeft
*/
defaultJoin(broadcastedRelation)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 9c173d7bf1..d46a80423f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{IntegralType, LongType}
@@ -43,6 +44,8 @@ trait HashJoin {
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
+ case j: ExistenceJoin =>
+ left.output :+ j.exists
case LeftExistence(_) =>
left.output
case x =>
@@ -50,6 +53,8 @@ trait HashJoin {
}
}
+ override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
protected lazy val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
@@ -110,15 +115,14 @@ trait HashJoin {
(r: InternalRow) => true
}
- protected def createResultProjection(): (InternalRow) => InternalRow = {
- if (joinType == LeftSemi) {
+ protected def createResultProjection(): (InternalRow) => InternalRow = joinType match {
+ case LeftExistence(_) =>
UnsafeProjection.create(output, output)
- } else {
+ case _ =>
// Always put the stream side on left to simplify implementation
// both of left and right side could be null
UnsafeProjection.create(
output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
- }
}
private def innerJoin(
@@ -184,6 +188,23 @@ trait HashJoin {
}
}
+ private def existenceJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator()
+ val result = new GenericMutableRow(Array[Any](null))
+ val joinedRow = new JoinedRow
+ streamIter.map { current =>
+ val key = joinKeys(current)
+ lazy val buildIter = hashedRelation.get(key)
+ val exists = !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists {
+ (row: InternalRow) => boundCondition(joinedRow(current, row))
+ })
+ result.setBoolean(0, exists)
+ joinedRow(current, result)
+ }
+ }
+
private def antiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {
@@ -212,6 +233,8 @@ trait HashJoin {
semiJoin(streamedIter, hashed)
case LeftAnti =>
antiJoin(streamedIter, hashed)
+ case j: ExistenceJoin =>
+ existenceJoin(streamedIter, hashed)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 3ef2fec352..0036f9aadc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
@@ -44,17 +44,6 @@ case class ShuffledHashJoinExec(
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
- override def outputPartitioning: Partitioning = joinType match {
- case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
- case LeftAnti => left.outputPartitioning
- case LeftSemi => left.outputPartitioning
- case LeftOuter => left.outputPartitioning
- case RightOuter => right.outputPartitioning
- case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
- case x =>
- throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType")
- }
-
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 775f8ac508..f0efa52c3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -53,6 +53,8 @@ case class SortMergeJoinExec(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
(left.output ++ right.output).map(_.withNullability(true))
+ case j: ExistenceJoin =>
+ left.output :+ j.exists
case LeftExistence(_) =>
left.output
case x =>
@@ -269,6 +271,44 @@ case class SortMergeJoinExec(
override def getRow: InternalRow = currentLeftRow
}.toScala
+ case j: ExistenceJoin =>
+ new RowIterator {
+ private[this] var currentLeftRow: InternalRow = _
+ private[this] val result: MutableRow = new GenericMutableRow(Array[Any](null))
+ private[this] val smjScanner = new SortMergeJoinScanner(
+ createLeftKeyGenerator(),
+ createRightKeyGenerator(),
+ keyOrdering,
+ RowIterator.fromScala(leftIter),
+ RowIterator.fromScala(rightIter)
+ )
+ private[this] val joinRow = new JoinedRow
+
+ override def advanceNext(): Boolean = {
+ while (smjScanner.findNextOuterJoinRows()) {
+ currentLeftRow = smjScanner.getStreamedRow
+ val currentRightMatches = smjScanner.getBufferedMatches
+ var found = false
+ if (currentRightMatches != null) {
+ var i = 0
+ while (!found && i < currentRightMatches.length) {
+ joinRow(currentLeftRow, currentRightMatches(i))
+ if (boundCondition(joinRow)) {
+ found = true
+ }
+ i += 1
+ }
+ }
+ result.setBoolean(0, found)
+ numOutputRows += 1
+ return true
+ }
+ false
+ }
+
+ override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
+ }.toScala
+
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin should not take $x as the JoinType")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 0bf4c6f960..ff3f9bb33f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -152,6 +152,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
}
+ test("EXISTS predicate subquery within OR") {
+ checkAnswer(
+ sql("select * from l where exists (select * from r where l.a = r.c)" +
+ " or exists (select * from r where l.a = r.c)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)" +
+ " or not exists (select * from r where l.a = r.c)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) ::
+ Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+ }
+
test("IN predicate subquery") {
checkAnswer(
sql("select * from l where l.a in (select c from r)"),
@@ -187,6 +200,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
}
+ test("IN predicate subquery within OR") {
+ checkAnswer(
+ sql("select * from l where l.a in (select c from r)" +
+ " or l.a in (select c from r where l.b < r.d)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+ intercept[AnalysisException] {
+ sql("select * from l where a not in (select c from r)" +
+ " or a not in (select c from r where c is not null)")
+ }
+ }
+
test("complex IN predicate subquery") {
checkAnswer(
sql("select * from l where (a, b) not in (select c, d from r)"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index b32b6444b6..8093054b6d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -18,15 +18,15 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
+import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructType}
class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
@@ -89,6 +89,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
ExtractEquiJoinKeys.unapply(join)
}
+ val existsAttr = AttributeReference("exists", BooleanType, false)()
+ val leftSemiPlus = ExistenceJoin(existsAttr)
+ def createLeftSemiPlusJoin(join: SparkPlan): SparkPlan = {
+ val output = join.output.dropRight(1)
+ val condition = if (joinType == LeftSemi) {
+ existsAttr
+ } else {
+ Not(existsAttr)
+ }
+ ProjectExec(output, FilterExec(condition, join))
+ }
+
test(s"$testName using ShuffledHashJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -98,6 +110,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
expectedAnswer,
sortAnswers = true)
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ createLeftSemiPlusJoin(ShuffledHashJoinExec(
+ leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, left, right))),
+ expectedAnswer,
+ sortAnswers = true)
}
}
}
@@ -111,6 +129,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
expectedAnswer,
sortAnswers = true)
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ createLeftSemiPlusJoin(BroadcastHashJoinExec(
+ leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, left, right))),
+ expectedAnswer,
+ sortAnswers = true)
}
}
}
@@ -123,6 +147,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer,
sortAnswers = true)
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ createLeftSemiPlusJoin(SortMergeJoinExec(
+ leftKeys, rightKeys, leftSemiPlus, boundCondition, left, right))),
+ expectedAnswer,
+ sortAnswers = true)
}
}
}
@@ -134,6 +164,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition))),
expectedAnswer,
sortAnswers = true)
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
+ left, right, BuildLeft, leftSemiPlus, Some(condition)))),
+ expectedAnswer,
+ sortAnswers = true)
}
}
@@ -144,6 +180,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition))),
expectedAnswer,
sortAnswers = true)
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
+ left, right, BuildRight, leftSemiPlus, Some(condition)))),
+ expectedAnswer,
+ sortAnswers = true)
}
}
}