From 83061be697f69f7e39deb9cda45742a323714231 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Fri, 29 Apr 2016 16:47:56 -0700
Subject: [SPARK-14858] [SQL] Enable subquery pushdown

The previous subquery PRs did not include support for pushing down subqueries used in filters (`WHERE`/`HAVING`). This PR adds that support. For example:

```scala
range(0, 10).registerTempTable("a")
range(5, 15).registerTempTable("b")
range(7, 25).registerTempTable("c")
range(3, 12).registerTempTable("d")
val plan = sql("select * from a join b on a.id = b.id left join c on c.id = b.id where a.id in (select id from d)")
plan.explain(true)
```

This leads to the following Analyzed & Optimized plans:

```
== Parsed Logical Plan ==
...

== Analyzed Logical Plan ==
id: bigint, id: bigint, id: bigint
Project [id#0L,id#4L,id#8L]
+- Filter predicate-subquery#16 [(id#0L = id#12L)]
   :  +- SubqueryAlias predicate-subquery#16 [(id#0L = id#12L)]
   :     +- Project [id#12L]
   :        +- SubqueryAlias d
   :           +- Range 3, 12, 1, 8, [id#12L]
   +- Join LeftOuter, Some((id#8L = id#4L))
      :- Join Inner, Some((id#0L = id#4L))
      :  :- SubqueryAlias a
      :  :  +- Range 0, 10, 1, 8, [id#0L]
      :  +- SubqueryAlias b
      :     +- Range 5, 15, 1, 8, [id#4L]
      +- SubqueryAlias c
         +- Range 7, 25, 1, 8, [id#8L]

== Optimized Logical Plan ==
Join LeftOuter, Some((id#8L = id#4L))
:- Join Inner, Some((id#0L = id#4L))
:  :- Join LeftSemi, Some((id#0L = id#12L))
:  :  :- Range 0, 10, 1, 8, [id#0L]
:  :  +- Range 3, 12, 1, 8, [id#12L]
:  +- Range 5, 15, 1, 8, [id#4L]
+- Range 7, 25, 1, 8, [id#8L]

== Physical Plan ==
...
```

I have also taken the opportunity to move quite a bit of code around:
- Rewriting subqueries and pulling correlated predicates out of subqueries has been moved into the analyzer. The analyzer transforms `Exists` and `InSubQuery` into `PredicateSubquery` expressions. A `PredicateSubquery` exposes the 'join' expressions and the proper references, which makes things like type coercion, optimization, and planning easier to do.
- I have added support for `Aggregate` plans in subqueries; any correlated expressions will be added to the grouping expressions. I have removed support for `Union` plans, since pulling an outer reference up from beneath a `Union` has no value (a filtered value could easily be part of another `Union` child).
- Resolution of subqueries is now done using `OuterReference`s. These are used to wrap any outer reference; this makes the identification of these references easier, and also makes dealing with duplicate attributes in the outer and inner plans easier. The resolution of subqueries initially used a resolution loop which would alternate between calling the analyzer and trying to resolve the outer references. We now use a dedicated analyzer which uses a special rule for outer reference resolution.

These changes are a stepping stone for enabling correlated scalar subqueries, enabling all Hive tests, and allowing us to use predicate subqueries anywhere.

Tested with the current tests and the test cases added to FilterPushdownSuite.

Author: Herman van Hovell

Closes #12720 from hvanhovell/SPARK-14858.
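The same machinery also covers `EXISTS`: after ResolveSubquery turns it into a `PredicateSubquery` expression, RewritePredicateSubquery plans it as a left semi join, and its negation as a left anti join. A minimal sketch in the same shell style as the example above, reusing the temp tables registered there (plan output elided):

```scala
// EXISTS becomes a LeftSemi join against the subquery's plan, with the
// correlated predicate pulled up into the join condition ...
sql("select * from a where exists (select 1 from d where d.id = a.id)").explain(true)
// ... and NOT EXISTS becomes a LeftAnti join.
sql("select * from a where not exists (select 1 from d where d.id = a.id)").explain(true)
```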
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 305 ++++++++++++++++----- .../sql/catalyst/analysis/CheckAnalysis.scala | 36 +-- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 25 -- .../catalyst/expressions/namedExpressions.scala | 10 + .../spark/sql/catalyst/expressions/subquery.scala | 106 +++---- .../spark/sql/catalyst/optimizer/Optimizer.scala | 136 ++------- .../spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 7 +- .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 24 +- .../catalyst/optimizer/FilterPushdownSuite.scala | 39 ++- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- .../apache/spark/sql/catalyst/plans/PlanTest.scala | 10 +- 12 files changed, 384 insertions(+), 318 deletions(-) (limited to 'sql/catalyst/src') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f6a65f7e6c..e98036a970 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.IntegerIndex import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.usePrettyExpression @@ -863,76 +863,246 @@ class Analyzer( } /** - * This rule resolves sub-queries inside expressions. + * This rule resolves and rewrites subqueries inside expressions. * * Note: CTEs are handled in CTESubstitution. */ object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper { - /** - * Resolve the correlated predicates in the clauses (e.g. WHERE or HAVING) of a - * sub-query by using the plan the predicates should be correlated to. + * Resolve the correlated expressions in a subquery by using an outer plan's references.
All + * resolved outer references are wrapped in an [[OuterReference]] */ - private def resolveCorrelatedSubquery( - sub: LogicalPlan, outer: LogicalPlan, - aliases: scala.collection.mutable.Map[Attribute, Alias]): LogicalPlan = { - // First resolve as much of the sub-query as possible - val analyzed = execute(sub) - if (analyzed.resolved) { - analyzed - } else { - // Only resolve the lowest plan that is not resolved by outer plan, otherwise it could be - // resolved by itself - val resolvedByOuter = analyzed transformDown { - case q: LogicalPlan if q.childrenResolved && !q.resolved => - q transformExpressions { - case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { - try { - val outerAttrOpt = outer.resolve(nameParts, resolver) - if (outerAttrOpt.isDefined) { - val outerAttr = outerAttrOpt.get - if (q.inputSet.contains(outerAttr)) { - // Got a conflict, create an alias for the attribute come from outer table - val alias = Alias(outerAttr, outerAttr.toString)() - val attr = alias.toAttribute - aliases += attr -> alias - attr - } else { - outerAttr - } - } else { - u - } - } catch { - case a: AnalysisException => u + private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { + plan transformDown { + case q: LogicalPlan if q.childrenResolved && !q.resolved => + q transformExpressions { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { + try { + outer.resolve(nameParts, resolver) match { + case Some(outerAttr) => OuterReference(outerAttr) + case None => u } + } catch { + case _: AnalysisException => u } - } + } + } + } + } + + /** + * Pull out all (outer) correlated predicates from a given subquery. This method removes the + * correlated predicates from subquery [[Filter]]s and adds the references of these predicates + * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to + * be able to evaluate the predicates at the top level. + * + * This method returns the rewritten subquery and correlated predicates. + */ + private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + + /** Make sure a plan's subtree does not contain a tagged predicate. */ + def failOnOuterReferenceInSubTree(p: LogicalPlan, msg: String): Unit = { + if (p.collect(predicateMap).nonEmpty) { + failAnalysis(s"Accessing outer query column is not allowed in $msg: $p") } - if (resolvedByOuter fastEquals analyzed) { - analyzed - } else { - resolveCorrelatedSubquery(resolvedByOuter, outer, aliases) + } + + /** Helper function for locating outer references. */ + def containsOuter(e: Expression): Boolean = { + e.find(_.isInstanceOf[OuterReference]).isDefined + } + + /** Make sure a plan's expressions do not contain a tagged predicate. */ + def failOnOuterReference(p: LogicalPlan): Unit = { + if (p.expressions.exists(containsOuter)) { + failAnalysis( + s"Correlated predicates are not supported outside of WHERE/HAVING clauses: $p") } } + + /** Determine which correlated predicate references are missing from this plan. */ + def missingReferences(p: LogicalPlan): AttributeSet = { + val localPredicateReferences = p.collect(predicateMap) + .flatten + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + localPredicateReferences -- p.outputSet + } + + val transformed = sub transformUp { + case f @ Filter(cond, child) => + // Find all predicates with an outer reference.
+ val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + + // Rewrite the filter without the correlated predicates if any. + correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> xs + newFilter + case xs => + predicateMap += child -> xs + child + } + case p @ Project(expressions, child) => + failOnOuterReference(p) + val referencesToAdd = missingReferences(p) + if (referencesToAdd.nonEmpty) { + Project(expressions ++ referencesToAdd, child) + } else { + p + } + case a @ Aggregate(grouping, expressions, child) => + failOnOuterReference(a) + val referencesToAdd = missingReferences(a) + if (referencesToAdd.nonEmpty) { + Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) + } else { + a + } + case j @ Join(left, _, RightOuter, _) => + failOnOuterReference(j) + failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") + j + case j @ Join(_, right, jt, _) if jt != Inner => + failOnOuterReference(j) + failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN") + j + case u: Union => + failOnOuterReferenceInSubTree(u, "a UNION") + u + case s: SetOperation => + failOnOuterReferenceInSubTree(s.right, "an INTERSECT/EXCEPT") + s + case e: Expand => + failOnOuterReferenceInSubTree(e, "an EXPAND") + e + case p => + failOnOuterReference(p) + p + } + (transformed, predicateMap.values.flatten.toSeq) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Only a few unary node (Project/Filter/Aggregate/Having) could have subquery - case q: UnaryNode if q.childrenResolved => - val aliases = scala.collection.mutable.Map[Attribute, Alias]() - val newPlan = q transformExpressions { - case e: SubqueryExpression if !e.query.resolved => - e.withNewPlan(resolveCorrelatedSubquery(e.query, q.child, aliases)) + /** + * Rewrite the subquery in a safe way by preventing the subquery and the outer plan from + * using the same attributes. + */ + private def rewriteSubQuery( + sub: LogicalPlan, + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + // Pull out the tagged predicates and rewrite the subquery in the process. + val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) + + // Make sure the inner and the outer query attributes do not collide. + val outputSet = outer.map(_.outputSet).reduce(_ ++ _) + val duplicates = basePlan.outputSet.intersect(outputSet) + val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = basePlan.output.map { ref => + aliasMap.getOrElse(ref, ref) } - if (aliases.nonEmpty) { - val projs = q.child.output ++ aliases.values - Project(q.child.output, - newPlan.withNewChildren(Seq(Project(projs, q.child)))) - } else { - newPlan + val aliasedProjection = Project(aliasedExpressions, basePlan) + val aliasedConditions = baseConditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (basePlan, baseConditions) + } + // Remove outer references from the correlated predicates. We defer extracting + // these until collisions between the inner and outer query attributes have been + // resolved. + val conditions = deDuplicatedConditions.map(_.transform { + case OuterReference(ref) => ref + }) + (plan, conditions) + } + + /** + * Resolve and rewrite a subquery.
The subquery is resolved using its outer plans. This method + * will resolve the subquery by alternating between the regular analyzer and the + * resolveOuterReferences rule. + * + * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved. + */ + private def resolveSubQuery( + e: SubqueryExpression, + plans: Seq[LogicalPlan], + requiredColumns: Int = 0)( + f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { + // Step 1: Resolve the outer expressions. + var previous: LogicalPlan = null + var current = e.query + do { + // Try to resolve the subquery plan using the regular analyzer. + previous = current + current = execute(current) + + // Use the outer references to resolve the subquery plan if it isn't resolved yet. + val i = plans.iterator + val afterResolve = current + while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) { + current = resolveOuterReferences(current, i.next()) + } + } while (!current.resolved && !current.fastEquals(previous)) + + // Step 2: Pull out the predicates if the plan is resolved. + if (current.resolved) { + // Make sure the resolved query has the required number of output columns. This is only + // needed for IN expressions. + if (requiredColumns > 0 && requiredColumns != current.output.size) { + failAnalysis(s"The number of fields in the value ($requiredColumns) does not " + s"match with the number of columns in the subquery (${current.output.size})") } + // Pull out predicates and construct a new plan. + f.tupled(rewriteSubQuery(current, plans)) + } else { + e.withNewPlan(current) + } + } + + /** + * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS + * expressions into PredicateSubquery expressions once they are resolved. + */ + private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { + plan transformExpressions { + case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => + resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) + case e @ Exists(sub, exprId) => + resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) + case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => + // Get the left hand side expressions. + val expressions = e match { + case CreateStruct(exprs) => exprs + case expr => Seq(expr) + } + resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => + // Construct the IN conditions. + val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) + PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) + } + } + } + + /** + * Resolve and rewrite all subqueries in an operator tree. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // In case of HAVING (a filter after an aggregate) we use both the aggregate and + // its child for resolution. + case f @ Filter(_, a: Aggregate) if f.childrenResolved => + resolveSubQueries(f, Seq(a, a.child)) + // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. + case q: UnaryNode if q.childrenResolved => + resolveSubQueries(q, q.children) } } @@ -986,12 +1156,24 @@ class Analyzer( // If resolution was successful and we see the filter has an aggregate in it, add it to // the original aggregate operator.
- if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { - val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + if (resolvedOperator.resolved) { + // Try to replace all aggregate expressions in the filter by an alias. + val aggregateExpressions = ArrayBuffer.empty[NamedExpression] + val transformedAggregateFilter = resolvedAggregateFilter.transform { + case ae: AggregateExpression => + val alias = Alias(ae, ae.toString)() + aggregateExpressions += alias + alias.toAttribute + } - Project(aggregate.output, - Filter(resolvedAggregateFilter.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + // Push the aggregate expressions into the aggregate (if any). + if (aggregateExpressions.nonEmpty) { + Project(aggregate.output, + Filter(transformedAggregateFilter, + aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) + } else { + filter + } } else { filter } @@ -1836,3 +2018,4 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 74f434e063..61a7d9ea24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -60,6 +60,9 @@ trait CheckAnalysis extends PredicateHelper { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") + case ScalarSubquery(_, conditions, _) if conditions.nonEmpty => + failAnalysis("Correlated scalar subqueries are not supported.") + case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => @@ -101,7 +104,6 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } - } operator match { @@ -111,38 +113,8 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${f.condition.dataType.simpleString} is not a boolean.") case f @ Filter(condition, child) => - // Make sure that no correlated reference is below Aggregates, Outer Joins and on the - // right hand side of Unions. 
- lazy val outerAttributes = child.outputSet def failOnCorrelatedReference( - plan: LogicalPlan, - message: String): Unit = plan foreach { - case p => - lazy val inputs = p.inputSet - p.transformExpressions { - case e: AttributeReference - if !inputs.contains(e) && outerAttributes.contains(e) => - failAnalysis(s"Accessing outer query column is not allowed in $message: $e") - } - } - def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach { - case a @ Aggregate(_, _, source) => - failOnCorrelatedReference(source, "an AGGREGATE") - case j @ Join(left, _, RightOuter, _) => - failOnCorrelatedReference(left, "a RIGHT OUTER JOIN") - case j @ Join(_, right, jt, _) if jt != Inner => - failOnCorrelatedReference(right, "a LEFT (OUTER) JOIN") - case Union(_ :: xs) => - xs.foreach(failOnCorrelatedReference(_, "a UNION")) - case s: SetOperation => - failOnCorrelatedReference(s.right, "an INTERSECT/EXCEPT") - case _ => - } splitConjunctivePredicates(condition).foreach { - case p: PredicateSubquery => - checkForCorrelatedReferences(p) - case Not(p: PredicateSubquery) => - checkForCorrelatedReferences(p) + case _: PredicateSubquery | Not(_: PredicateSubquery) => case e if PredicateSubquery.hasPredicateSubquery(e) => failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e") case e => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 0306afb0d8..5323b79c57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -376,31 +376,6 @@ object HiveTypeCoercion { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } - - case InSubQuery(struct: CreateStruct, subquery, exprId) - if struct.children.zip(subquery.output).exists(x => x._1.dataType != x._2.dataType) => - val widerTypes: Seq[Option[DataType]] = struct.children.zip(subquery.output).map { - case (l, r) => findWiderTypeForTwo(l.dataType, r.dataType) - } - val newStruct = struct.withNewChildren(struct.children.zip(widerTypes).map { - case (e, Some(t)) => Cast(e, t) - case (e, _) => e - }) - val newSubquery = Project(subquery.output.zip(widerTypes).map { - case (a, Some(t)) => Alias(Cast(a, t), a.toString)() - case (a, _) => a - }, subquery) - InSubQuery(newStruct, newSubquery, exprId) - - case sub @ InSubQuery(expr, subquery, exprId) - if expr.dataType != subquery.output.head.dataType => - findWiderTypeForTwo(expr.dataType, subquery.output.head.dataType) match { - case Some(t) => - val attr = subquery.output.head - val proj = Seq(Alias(Cast(attr, t), attr.toString)()) - InSubQuery(Cast(expr, t), Project(proj, subquery), exprId) - case _ => sub - } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8b38838537..306a99d5a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -337,6 +337,16 @@ case class PrettyAttribute( override def nullable: Boolean = true } +/** + * A placeholder used to hold a reference that has been resolved to a field outside of the current
plan. This is used for correlated subqueries. + */ +case class OuterReference(e: NamedExpression) extends LeafExpression with Unevaluable { + override def dataType: DataType = e.dataType + override def nullable: Boolean = e.nullable + override def prettyName: String = "outer" +} + object VirtualColumn { // The attribute name used by Hive, which has different result than Spark, deprecated. val hiveGroupingIdName: String = "grouping__id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 1993bd2587..cd6d3a00b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.types._ * An interface for subquery that is used in expressions. */ abstract class SubqueryExpression extends Expression { + /** The id of the subquery expression. */ + def exprId: ExprId - /** - * The logical plan of the query. - */ + /** The logical plan of the query. */ def query: LogicalPlan /** @@ -38,31 +38,30 @@ abstract class SubqueryExpression extends Expression { */ def plan: QueryPlan[_] - /** - * Updates the query with new logical plan. - */ + /** Updates the query with new logical plan. */ def withNewPlan(plan: LogicalPlan): SubqueryExpression + + protected def conditionString: String = children.mkString("[", " && ", "]") } /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. * - * Note: `exprId` is used to have unique name in explain string output. + * Note: `exprId` is used to have a unique name in explain string output. */ case class ScalarSubquery( query: LogicalPlan, + children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) extends SubqueryExpression with Unevaluable { override def plan: LogicalPlan = SubqueryAlias(toString, query) - override lazy val resolved: Boolean = query.resolved + override lazy val resolved: Boolean = childrenResolved && query.resolved override def dataType: DataType = query.schema.fields.head.dataType - override def children: Seq[Expression] = Nil - override def checkInputDataTypes(): TypeCheckResult = { if (query.schema.length != 1) { TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + @@ -75,9 +74,9 @@ case class ScalarSubquery( override def foldable: Boolean = false override def nullable: Boolean = true - override def withNewPlan(plan: LogicalPlan): ScalarSubquery = ScalarSubquery(plan, exprId) + override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan) - override def toString: String = s"subquery#${exprId.id}" + override def toString: String = s"subquery#${exprId.id} $conditionString" } /** @@ -85,18 +84,34 @@ case class ScalarSubquery( * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will * be rewritten into a left semi/anti join during analysis. 
*/ -abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate { +case class PredicateSubquery( + query: LogicalPlan, + children: Seq[Expression] = Seq.empty, + nullAware: Boolean = false, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression with Predicate with Unevaluable { + override lazy val resolved = childrenResolved && query.resolved + override lazy val references: AttributeSet = super.references -- query.outputSet override def nullable: Boolean = false + override def plan: LogicalPlan = SubqueryAlias(toString, query) + override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan) + override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" } object PredicateSubquery { def hasPredicateSubquery(e: Expression): Boolean = { - e.find(_.isInstanceOf[PredicateSubquery]).isDefined + e.find { + case _: PredicateSubquery | _: ListQuery | _: Exists => true + case _ => false + }.isDefined } } /** - * The [[InSubQuery]] predicate checks the existence of a value in a sub-query. For example (SQL): + * A [[ListQuery]] expression defines the query which we want to search in an IN subquery + * expression. It should and can only be used in conjunction with a IN expression. + * + * For example (SQL): * {{{ * SELECT * * FROM a @@ -104,47 +119,15 @@ object PredicateSubquery { * FROM b) * }}} */ -case class InSubQuery( - value: Expression, - query: LogicalPlan, - exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery { - override def children: Seq[Expression] = value :: Nil - override lazy val resolved: Boolean = value.resolved && query.resolved - override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan, exprId) - override def plan: LogicalPlan = SubqueryAlias(s"subquery#${exprId.id}", query) - - /** - * The unwrapped value side expressions. - */ - lazy val expressions: Seq[Expression] = value match { - case CreateStruct(cols) => cols - case col => Seq(col) - } - - /** - * Check if the number of columns and the data types on both sides match. - */ - override def checkInputDataTypes(): TypeCheckResult = { - // Check the number of arguments. - if (expressions.length != query.output.length) { - return TypeCheckResult.TypeCheckFailure( - s"The number of fields in the value (${expressions.length}) does not match with " + - s"the number of columns in the subquery (${query.output.length})") - } - - // Check the argument types. 
- expressions.zip(query.output).zipWithIndex.foreach { - case ((e, a), i) if e.dataType != a.dataType => - return TypeCheckResult.TypeCheckFailure( - s"The data type of value[$i] (${e.dataType}) does not match " + - s"subquery column '${a.name}' (${a.dataType}).") - case _ => - } - - TypeCheckResult.TypeCheckSuccess - } - - override def toString: String = s"$value IN subquery#${exprId.id}" +case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression with Unevaluable { + override lazy val resolved = false + override def children: Seq[Expression] = Seq.empty + override def dataType: DataType = ArrayType(NullType) + override def nullable: Boolean = false + override def withNewPlan(plan: LogicalPlan): ListQuery = copy(query = plan) + override def plan: LogicalPlan = SubqueryAlias(toString, query) + override def toString: String = s"list#${exprId.id}" } /** @@ -158,11 +141,12 @@ case class InSubQuery( * WHERE b.id = a.id) * }}} */ -case class Exists( - query: LogicalPlan, - exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery { - override def children: Seq[Expression] = Nil - override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan, exprId) +case class Exists(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression with Predicate with Unevaluable { + override lazy val resolved = false + override def children: Seq[Expression] = Seq.empty + override def nullable: Boolean = false + override def withNewPlan(plan: LogicalPlan): Exists = copy(query = plan) override def plan: LogicalPlan = SubqueryAlias(toString, query) override def toString: String = s"exists#${exprId.id}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index abbd8facd3..0b70edec8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet -import scala.collection.mutable import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{InSubQuery, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} @@ -48,7 +47,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // However, because we also use the analyzer to canonicalized queries (for view definition), // we do not eliminate subqueries or compute current time in the analyzer. Batch("Finish Analysis", Once, - RewritePredicateSubquery, EliminateSubqueryAliases, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), @@ -63,6 +61,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // since the other rules might make two separate Unions operators adjacent. 
Batch("Union", Once, CombineUnions) :: + Batch("Subquery", Once, + OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, ReplaceIntersectWithSemiJoin, ReplaceExceptWithAntiJoin, @@ -99,15 +99,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, - EliminateSerialization) :: + EliminateSerialization, + RewritePredicateSubquery) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: Batch("Typed Filter Optimization", fixedPoint, EmbedSerializerInFilter) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation) :: - Batch("Subquery", Once, - OptimizeSubqueries) :: Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Nil } @@ -117,8 +116,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) */ object OptimizeSubqueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case subquery: SubqueryExpression => - subquery.withNewPlan(Optimizer.this.execute(subquery.query)) + case s: SubqueryExpression => + s.withNewPlan(Optimizer.this.execute(s.query)) } } } @@ -636,7 +635,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe // Only consider constraints that can be pushed down completely to either the left or the // right child val constraints = join.constraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)} + c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) + } // Remove those constraints that are already enforced by either the left or the right child val additionalConstraints = constraints -- (left.constraints ++ right.constraints) val newConditionOpt = conditionOpt match { @@ -1123,7 +1123,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { * Returns whether the expression returns null or false when all inputs are nulls. */ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic) return false + if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val v = BindReferences.bindReference(e, attributes).eval(emptyRow) @@ -1503,94 +1503,6 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] { * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - /** - * Pull out all correlated predicates from a given sub-query. This method removes the correlated - * predicates from sub-query [[Filter]]s and adds the references of these predicates to - * all intermediate [[Project]] clauses (if they are missing) in order to be able to evaluate the - * predicates in the join condition. - * - * This method returns the rewritten sub-query and the combined (AND) extracted predicate. - */ - private def pullOutCorrelatedPredicates( - subquery: LogicalPlan, - query: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - val references = query.outputSet - val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]] - val transformed = subquery transformUp { - case f @ Filter(cond, child) => - // Find all correlated predicates. - val (correlated, local) = splitConjunctivePredicates(cond).partition { e => - (e.references -- child.outputSet).intersect(references).nonEmpty - } - // Rewrite the filter without the correlated predicates if any. 
- correlated match { - case Nil => f - case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) - predicateMap += newFilter -> correlated - newFilter - case xs => - predicateMap += child -> correlated - child - } - case p @ Project(expressions, child) => - // Find all pulled out predicates defined in the Project's subtree. - val localPredicates = p.collect(predicateMap).flatten - - // Determine which correlated predicate references are missing from this project. - val localPredicateReferences = localPredicates - .map(_.references) - .reduceOption(_ ++ _) - .getOrElse(AttributeSet.empty) - val missingReferences = localPredicateReferences -- p.references -- query.outputSet - - // Create a new project if we need to add missing references. - if (missingReferences.nonEmpty) { - Project(expressions ++ missingReferences, child) - } else { - p - } - } - (transformed, predicateMap.values.flatten.toSeq) - } - - /** - * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by - * constructing the required join condition. Both the rewritten subquery and the constructed - * join condition are returned. - */ - private def pullOutCorrelatedPredicates( - in: InSubQuery, - query: LogicalPlan): (LogicalPlan, LogicalPlan, Seq[Expression]) = { - val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query) - // Check whether there is some attributes have same exprId but come from different side - val outerAttributes = AttributeSet(in.expressions.flatMap(_.references)) - if (outerAttributes.intersect(resolved.outputSet).nonEmpty) { - val aliases = mutable.Map[Attribute, Alias]() - val exprs = in.expressions.map { expr => - expr transformUp { - case a: AttributeReference if resolved.outputSet.contains(a) => - val alias = Alias(a, a.toString)() - val attr = alias.toAttribute - aliases += attr -> alias - attr - } - } - val newP = Project(query.output ++ aliases.values, query) - val projection = resolved.output.map { - case a if outerAttributes.contains(a) => Alias(a, a.toString)() - case a => a - } - val subquery = Project(projection, resolved) - val conditions = joinCondition ++ exprs.zip(subquery.output).map(EqualTo.tupled) - (newP, subquery, conditions) - } else { - val conditions = - joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled) - (query, resolved, conditions) - } - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(condition, child) => val (withSubquery, withoutSubquery) = @@ -1604,22 +1516,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. 
withSubquery.foldLeft(newFilter) { - case (p, Exists(sub, _)) => - val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) - Join(p, resolved, LeftSemi, conditions.reduceOption(And)) - case (p, Not(Exists(sub, _))) => - val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) - Join(p, resolved, LeftAnti, conditions.reduceOption(And)) - case (p, in: InSubQuery) => - val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p) - if (newP fastEquals p) { - Join(p, resolved, LeftSemi, conditions.reduceOption(And)) - } else { - Project(p.output, - Join(newP, resolved, LeftSemi, conditions.reduceOption(And))) - } - case (p, Not(in: InSubQuery)) => - val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p) + case (p, PredicateSubquery(sub, conditions, _, _)) => + Join(p, sub, LeftSemi, conditions.reduceOption(And)) + case (p, Not(PredicateSubquery(sub, conditions, false, _))) => + Join(p, sub, LeftAnti, conditions.reduceOption(And)) + case (p, Not(PredicateSubquery(sub, conditions, true, _))) => // This is a NULL-aware (left) anti join (NAAJ). // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. @@ -1628,12 +1529,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS // if performance matters to you. - if (newP fastEquals p) { - Join(p, resolved, LeftAnti, Option(Or(anyNull, condition))) - } else { - Project(p.output, - Join(newP, resolved, LeftAnti, Option(Or(anyNull, condition)))) - } + Join(p, sub, LeftAnti, Option(Or(anyNull, condition))) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7f98c21af2..1f923f47dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -956,7 +956,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(InSubQuery(e, plan(ctx.query))) + invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b358e210da..b2297bbcaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -107,8 +107,11 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = child.maxRows - override protected def validConstraints: Set[Expression] = - child.constraints.union(splitConjunctivePredicates(condition).toSet) + override protected def validConstraints: Set[Expression] = { + val predicates = splitConjunctivePredicates(condition) + .filterNot(PredicateSubquery.hasPredicateSubquery) + child.constraints.union(predicates.toSet) + } } abstract class 
SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index a90636d278..1b08913ddd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} @@ -449,7 +450,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(InSubQuery(a, LocalRelation(b)), "c")()), + Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -458,10 +459,10 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(InSubQuery(a, LocalRelation(b)), BooleanType), LocalRelation(a)) + val plan1 = Filter(Cast(In(a, Seq(ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(InSubQuery(a, LocalRelation(b)), c), LocalRelation(a, c)) + val plan2 = Filter(Or(In(a, Seq(ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil) } @@ -474,7 +475,7 @@ class AnalysisErrorSuite extends AnalysisTest { Exists( Join( LocalRelation(b), - Filter(EqualTo(a, c), LocalRelation(c)), + Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), LeftOuter, Option(EqualTo(b, c)))), LocalRelation(a)) @@ -483,7 +484,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Filter( Exists( Join( - Filter(EqualTo(a, c), LocalRelation(c)), + Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), LocalRelation(b), RightOuter, Option(EqualTo(b, c)))), @@ -491,13 +492,16 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) val plan3 = Filter( - Exists(Aggregate(Seq.empty, Seq.empty, Filter(EqualTo(a, c), LocalRelation(c)))), + Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) + } - val plan4 = Filter( - Exists(Union(LocalRelation(b), Filter(EqualTo(a, c), LocalRelation(c)))), - LocalRelation(a)) - assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) + test("Correlated Scalar Subquery") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val sub = 
Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) + val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a)) + assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index e9b4bb002b..fcc14a803b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.IntegerType @@ -725,6 +725,43 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctedAnswer) } + test("predicate subquery: push down simple") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z) + + val query = x + .join(y, Inner, Option("x.a".attr === "y.a".attr)) + .where(Exists(z.where("x.a".attr === "z.a".attr))) + .analyze + val answer = x + .where(Exists(z.where("x.a".attr === "z.a".attr))) + .join(y, Inner, Option("x.a".attr === "y.a".attr)) + .analyze + val optimized = Optimize.execute(Optimize.execute(query)) + comparePlans(optimized, answer) + } + + test("predicate subquery: push down complex") { + val w = testRelation.subquery('w) + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z) + + val query = w + .join(x, Inner, Option("w.a".attr === "x.a".attr)) + .join(y, LeftOuter, Option("x.a".attr === "y.a".attr)) + .where(Exists(z.where("w.a".attr === "z.a".attr))) + .analyze + val answer = w + .where(Exists(z.where("w.a".attr === "z.a".attr))) + .join(x, Inner, Option("w.a".attr === "x.a".attr)) + .join(y, LeftOuter, Option("x.a".attr === "y.a".attr)) + .analyze + val optimized = Optimize.execute(Optimize.execute(query)) + comparePlans(optimized, answer) + } + test("Window: predicate push down -- basic") { val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 5af3ea9c7a..e73592c7af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -146,7 +146,7 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - InSubQuery('a, table("c").select('b))) + In('a, Seq(ListQuery(table("c").select('b))))) } test("like expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index f5439d70ad..6310f0c2bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -34,11 +34,13 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { protected def normalizeExprIds(plan: LogicalPlan) = { plan transformAllExpressions { case s: ScalarSubquery => - ScalarSubquery(s.query, ExprId(0)) - case s: InSubQuery => - InSubQuery(s.value, s.query, ExprId(0)) + s.copy(exprId = ExprId(0)) case e: Exists => - Exists(e.query, ExprId(0)) + e.copy(exprId = ExprId(0)) + case l: ListQuery => + l.copy(exprId = ExprId(0)) + case p: PredicateSubquery => + p.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias =>
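For the null-aware NOT IN case handled at the end of RewritePredicateSubquery, a minimal sketch of what triggers the NAAJ rewrite, again assuming the temp tables registered in the commit message example (plan output elided):

```scala
// NOT IN is null-aware: a NULL produced by any of the IN conditions counts as a
// positive result, so such rows are filtered out by the LeftAnti join via the
// Or(anyNull, condition) join condition constructed above. Per the code comment,
// this will almost certainly be planned as a broadcast nested loop join, so
// prefer NOT EXISTS when performance matters.
sql("select * from a where a.id not in (select id from d)").explain(true)
```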