author     Herman van Hovell <hvanhovell@questtec.nl>    2016-04-29 16:47:56 -0700
committer  Davies Liu <davies.liu@gmail.com>             2016-04-29 16:50:12 -0700
commit     83061be697f69f7e39deb9cda45742a323714231 (patch)
tree       b20fad5198d43bbdb83ecf0ea80cd0834f45bcb7 /sql/catalyst/src
parent     1eda2f10d9f7add319e5b271488045c44ea30c03 (diff)
[SPARK-14858] [SQL] Enable subquery pushdown
The previous subquery PRs did not include support for pushing down subqueries used in filters (`WHERE`/`HAVING`). This PR adds that support. For example:

```scala
range(0, 10).registerTempTable("a")
range(5, 15).registerTempTable("b")
range(7, 25).registerTempTable("c")
range(3, 12).registerTempTable("d")
val plan = sql("select * from a join b on a.id = b.id left join c on c.id = b.id where a.id in (select id from d)")
plan.explain(true)
```

This leads to the following Analyzed & Optimized plans:

```
== Parsed Logical Plan ==
...

== Analyzed Logical Plan ==
id: bigint, id: bigint, id: bigint
Project [id#0L,id#4L,id#8L]
+- Filter predicate-subquery#16 [(id#0L = id#12L)]
   :  +- SubqueryAlias predicate-subquery#16 [(id#0L = id#12L)]
   :     +- Project [id#12L]
   :        +- SubqueryAlias d
   :           +- Range 3, 12, 1, 8, [id#12L]
   +- Join LeftOuter, Some((id#8L = id#4L))
      :- Join Inner, Some((id#0L = id#4L))
      :  :- SubqueryAlias a
      :  :  +- Range 0, 10, 1, 8, [id#0L]
      :  +- SubqueryAlias b
      :     +- Range 5, 15, 1, 8, [id#4L]
      +- SubqueryAlias c
         +- Range 7, 25, 1, 8, [id#8L]

== Optimized Logical Plan ==
Join LeftOuter, Some((id#8L = id#4L))
:- Join Inner, Some((id#0L = id#4L))
:  :- Join LeftSemi, Some((id#0L = id#12L))
:  :  :- Range 0, 10, 1, 8, [id#0L]
:  :  +- Range 3, 12, 1, 8, [id#12L]
:  +- Range 5, 15, 1, 8, [id#4L]
+- Range 7, 25, 1, 8, [id#8L]

== Physical Plan ==
...
```

I have also taken the opportunity to move quite a bit of code around:
- Rewriting subqueries and pulling correlated predicates out of subqueries has been moved into the analyzer. The analyzer transforms `Exists` and `InSubQuery` into `PredicateSubquery` expressions. A `PredicateSubquery` exposes the 'join' expressions and the proper references, which makes things like type coercion, optimization, and planning easier.
- I have added support for `Aggregate` plans in subqueries: any correlated expressions are added to the grouping expressions. I have removed support for `Union` plans, since pulling an outer reference up from beneath a Union has no value (a filtered value could easily be part of another Union child).
- Resolution of subqueries is now done using `OuterReference`s. These wrap any outer reference, which makes both identifying such references and dealing with duplicate attributes in the outer and inner plans easier. Resolution of subqueries initially used a resolution loop that alternated between calling the analyzer and trying to resolve the outer references; we now use a dedicated analyzer with a special rule for outer reference resolution.

These changes are a stepping stone for enabling correlated scalar subqueries, enabling all Hive tests, and allowing us to use predicate subqueries anywhere.

Tested with the current tests and the test cases added to FilterPushdownSuite.

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12720 from hvanhovell/SPARK-14858.
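As a quick usage sketch (hypothetical setup; assumes a `SQLContext` is in scope, as in the example above), these are the two predicate-subquery forms the new rules handle end to end:

```scala
// Hypothetical setup mirroring the registerTempTable calls above.
sqlContext.range(0, 10).registerTempTable("a")
sqlContext.range(3, 12).registerTempTable("d")

// IN subquery: analyzed into a PredicateSubquery and planned as a LEFT SEMI join.
sqlContext.sql("select * from a where id in (select id from d)").explain(true)

// NOT EXISTS with a correlated predicate: planned as a LEFT ANTI join.
sqlContext.sql("select * from a where not exists (select 1 from d where d.id = a.id)").explain(true)
```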
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala305
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala36
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala106
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala136
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala7
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala24
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala39
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala10
12 files changed, 384 insertions, 318 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f6a65f7e6c..e98036a970 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.IntegerIndex
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.usePrettyExpression
@@ -863,76 +863,246 @@ class Analyzer(
}
/**
- * This rule resolves sub-queries inside expressions.
+ * This rule resolves and rewrites subqueries inside expressions.
*
* Note: CTEs are handled in CTESubstitution.
*/
object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
-
/**
- * Resolve the correlated predicates in the clauses (e.g. WHERE or HAVING) of a
- * sub-query by using the plan the predicates should be correlated to.
+ * Resolve the correlated expressions in a subquery by using the outer plan's references. All
+ * resolved outer references are wrapped in an [[OuterReference]].
*/
- private def resolveCorrelatedSubquery(
- sub: LogicalPlan, outer: LogicalPlan,
- aliases: scala.collection.mutable.Map[Attribute, Alias]): LogicalPlan = {
- // First resolve as much of the sub-query as possible
- val analyzed = execute(sub)
- if (analyzed.resolved) {
- analyzed
- } else {
- // Only resolve the lowest plan that is not resolved by outer plan, otherwise it could be
- // resolved by itself
- val resolvedByOuter = analyzed transformDown {
- case q: LogicalPlan if q.childrenResolved && !q.resolved =>
- q transformExpressions {
- case u @ UnresolvedAttribute(nameParts) =>
- withPosition(u) {
- try {
- val outerAttrOpt = outer.resolve(nameParts, resolver)
- if (outerAttrOpt.isDefined) {
- val outerAttr = outerAttrOpt.get
- if (q.inputSet.contains(outerAttr)) {
- // Got a conflict, create an alias for the attribute come from outer table
- val alias = Alias(outerAttr, outerAttr.toString)()
- val attr = alias.toAttribute
- aliases += attr -> alias
- attr
- } else {
- outerAttr
- }
- } else {
- u
- }
- } catch {
- case a: AnalysisException => u
+ private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
+ plan transformDown {
+ case q: LogicalPlan if q.childrenResolved && !q.resolved =>
+ q transformExpressions {
+ case u @ UnresolvedAttribute(nameParts) =>
+ withPosition(u) {
+ try {
+ outer.resolve(nameParts, resolver) match {
+ case Some(outerAttr) => OuterReference(outerAttr)
+ case None => u
}
+ } catch {
+ case _: AnalysisException => u
}
- }
+ }
+ }
+ }
+ }
+
+ /**
+ * Pull out all (outer) correlated predicates from a given subquery. This method removes the
+ * correlated predicates from subquery [[Filter]]s and adds any missing references of these
+ * predicates to all intermediate [[Project]] and [[Aggregate]] clauses, so that the predicates
+ * can be evaluated at the top level.
+ *
+ * This method returns the rewritten subquery and correlated predicates.
+ */
+ private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+ val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
+
+ /** Make sure a plan's subtree does not contain a tagged predicate. */
+ def failOnOuterReferenceInSubTree(p: LogicalPlan, msg: String): Unit = {
+ if (p.collect(predicateMap).nonEmpty) {
+ failAnalysis(s"Accessing outer query column is not allowed in $msg: $p")
}
- if (resolvedByOuter fastEquals analyzed) {
- analyzed
- } else {
- resolveCorrelatedSubquery(resolvedByOuter, outer, aliases)
+ }
+
+ /** Helper function for locating outer references. */
+ def containsOuter(e: Expression): Boolean = {
+ e.find(_.isInstanceOf[OuterReference]).isDefined
+ }
+
+ /** Make sure a plan's expressions do not contain a tagged predicate. */
+ def failOnOuterReference(p: LogicalPlan): Unit = {
+ if (p.expressions.exists(containsOuter)) {
+ failAnalysis(
+ s"Correlated predicates are not supported outside of WHERE/HAVING clauses: $p")
}
}
+
+ /** Determine which correlated predicate references are missing from this plan. */
+ def missingReferences(p: LogicalPlan): AttributeSet = {
+ val localPredicateReferences = p.collect(predicateMap)
+ .flatten
+ .map(_.references)
+ .reduceOption(_ ++ _)
+ .getOrElse(AttributeSet.empty)
+ localPredicateReferences -- p.outputSet
+ }
+
+ val transformed = sub transformUp {
+ case f @ Filter(cond, child) =>
+ // Find all predicates with an outer reference.
+ val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)
+
+ // Rewrite the filter without the correlated predicates if any.
+ correlated match {
+ case Nil => f
+ case xs if local.nonEmpty =>
+ val newFilter = Filter(local.reduce(And), child)
+ predicateMap += newFilter -> xs
+ newFilter
+ case xs =>
+ predicateMap += child -> xs
+ child
+ }
+ case p @ Project(expressions, child) =>
+ failOnOuterReference(p)
+ val referencesToAdd = missingReferences(p)
+ if (referencesToAdd.nonEmpty) {
+ Project(expressions ++ referencesToAdd, child)
+ } else {
+ p
+ }
+ case a @ Aggregate(grouping, expressions, child) =>
+ failOnOuterReference(a)
+ val referencesToAdd = missingReferences(a)
+ if (referencesToAdd.nonEmpty) {
+ Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
+ } else {
+ a
+ }
+ case j @ Join(left, _, RightOuter, _) =>
+ failOnOuterReference(j)
+ failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN")
+ j
+ case j @ Join(_, right, jt, _) if jt != Inner =>
+ failOnOuterReference(j)
+ failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN")
+ j
+ case u: Union =>
+ failOnOuterReferenceInSubTree(u, "a UNION")
+ u
+ case s: SetOperation =>
+ failOnOuterReferenceInSubTree(s.right, "an INTERSECT/EXCEPT")
+ s
+ case e: Expand =>
+ failOnOuterReferenceInSubTree(e, "an EXPAND")
+ e
+ case p =>
+ failOnOuterReference(p)
+ p
+ }
+ (transformed, predicateMap.values.flatten.toSeq)
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- // Only a few unary node (Project/Filter/Aggregate/Having) could have subquery
- case q: UnaryNode if q.childrenResolved =>
- val aliases = scala.collection.mutable.Map[Attribute, Alias]()
- val newPlan = q transformExpressions {
- case e: SubqueryExpression if !e.query.resolved =>
- e.withNewPlan(resolveCorrelatedSubquery(e.query, q.child, aliases))
+ /**
+ * Rewrite the subquery in a safe way by preventing the subquery and the outer plan from using
+ * the same attributes.
+ */
+ private def rewriteSubQuery(
+ sub: LogicalPlan,
+ outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
+ // Pull out the tagged predicates and rewrite the subquery in the process.
+ val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub)
+
+ // Make sure the inner and the outer query attributes do not collide.
+ val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
+ val duplicates = basePlan.outputSet.intersect(outputSet)
+ val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
+ val aliasMap = AttributeMap(duplicates.map { dup =>
+ dup -> Alias(dup, dup.toString)()
+ }.toSeq)
+ val aliasedExpressions = basePlan.output.map { ref =>
+ aliasMap.getOrElse(ref, ref)
}
- if (aliases.nonEmpty) {
- val projs = q.child.output ++ aliases.values
- Project(q.child.output,
- newPlan.withNewChildren(Seq(Project(projs, q.child))))
- } else {
- newPlan
+ val aliasedProjection = Project(aliasedExpressions, basePlan)
+ val aliasedConditions = baseConditions.map(_.transform {
+ case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
+ })
+ (aliasedProjection, aliasedConditions)
+ } else {
+ (basePlan, baseConditions)
+ }
+ // Remove outer references from the correlated predicates. We defer extracting
+ // these until collisions between the inner and outer query attributes have been
+ // resolved.
+ val conditions = deDuplicatedConditions.map(_.transform {
+ case OuterReference(ref) => ref
+ })
+ (plan, conditions)
+ }
+
+ /**
+ * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method
+ * will resolve the subquery by alternating between the regular analyzer and the
+ * resolveOuterReferences rule.
+ *
+ * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved.
+ */
+ private def resolveSubQuery(
+ e: SubqueryExpression,
+ plans: Seq[LogicalPlan],
+ requiredColumns: Int = 0)(
+ f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
+ // Step 1: Resolve the outer expressions.
+ var previous: LogicalPlan = null
+ var current = e.query
+ do {
+ // Try to resolve the subquery plan using the regular analyzer.
+ previous = current
+ current = execute(current)
+
+ // Use the outer references to resolve the subquery plan if it isn't resolved yet.
+ val i = plans.iterator
+ val afterResolve = current
+ while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) {
+ current = resolveOuterReferences(current, i.next())
+ }
+ } while (!current.resolved && !current.fastEquals(previous))
+
+ // Step 2: Pull out the predicates if the plan is resolved.
+ if (current.resolved) {
+ // Make sure the resolved query has the required number of output columns. This is only
+ // needed for IN expressions.
+ if (requiredColumns > 0 && requiredColumns != current.output.size) {
+ failAnalysis(s"The number of fields in the value ($requiredColumns) does not " +
+ s"match with the number of columns in the subquery (${current.output.size})")
}
+ // Pull out predicates and construct a new plan.
+ f.tupled(rewriteSubQuery(current, plans))
+ } else {
+ e.withNewPlan(current)
+ }
+ }
+
+ /**
+ * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS
+ * expressions into PredicateSubquery expressions once they are resolved.
+ */
+ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
+ plan transformExpressions {
+ case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
+ resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
+ case e @ Exists(sub, exprId) =>
+ resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
+ case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
+ // Get the left hand side expressions.
+ val expressions = e match {
+ case CreateStruct(exprs) => exprs
+ case expr => Seq(expr)
+ }
+ resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
+ // Construct the IN conditions.
+ val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled)
+ PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId)
+ }
+ }
+ }
+
+ /**
+ * Resolve and rewrite all subqueries in an operator tree.
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ // In case of HAVING (a filter after an aggregate) we use both the aggregate and
+ // its child for resolution.
+ case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
+ resolveSubQueries(f, Seq(a, a.child))
+ // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
+ case q: UnaryNode if q.childrenResolved =>
+ resolveSubQueries(q, q.children)
}
}
@@ -986,12 +1156,24 @@ class Analyzer(
// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
- if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
- val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
+ if (resolvedOperator.resolved) {
+ // Try to replace all aggregate expressions in the filter by an alias.
+ val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
+ val transformedAggregateFilter = resolvedAggregateFilter.transform {
+ case ae: AggregateExpression =>
+ val alias = Alias(ae, ae.toString)()
+ aggregateExpressions += alias
+ alias.toAttribute
+ }
- Project(aggregate.output,
- Filter(resolvedAggregateFilter.toAttribute,
- aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+ // Push the aggregate expressions into the aggregate (if any).
+ if (aggregateExpressions.nonEmpty) {
+ Project(aggregate.output,
+ Filter(transformedAggregateFilter,
+ aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
+ } else {
+ filter
+ }
} else {
filter
}
@@ -1836,3 +2018,4 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
}
+
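To make the new `pullOutCorrelatedPredicates` flow above concrete, here is a minimal, self-contained Scala sketch. The types are toy stand-ins for Catalyst's expression classes (hypothetical, for illustration only), showing how a subquery filter condition is split into correlated and local conjuncts:

```scala
// Toy stand-ins for Catalyst's expression types (hypothetical, illustration only).
sealed trait Expr
case class Attr(name: String) extends Expr
case class OuterRef(a: Attr) extends Expr        // mirrors OuterReference
case class EqualTo(l: Expr, r: Expr) extends Expr
case class And(l: Expr, r: Expr) extends Expr

object PullOutSketch extends App {
  // splitConjunctivePredicates: flatten a tree of Ands into its conjuncts.
  def splitConjuncts(e: Expr): Seq[Expr] = e match {
    case And(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
    case other     => Seq(other)
  }

  // containsOuter: does the expression (transitively) reference the outer query?
  def containsOuter(e: Expr): Boolean = e match {
    case _: OuterRef   => true
    case EqualTo(l, r) => containsOuter(l) || containsOuter(r)
    case And(l, r)     => containsOuter(l) || containsOuter(r)
    case _             => false
  }

  // The Filter case of pullOutCorrelatedPredicates: correlated conjuncts move into
  // the pending join condition, local conjuncts stay in the rewritten Filter.
  val cond = And(EqualTo(Attr("d.id"), OuterRef(Attr("a.id"))), EqualTo(Attr("d.b"), Attr("d.c")))
  val (correlated, local) = splitConjuncts(cond).partition(containsOuter)
  println(s"correlated = $correlated") // List(EqualTo(Attr(d.id), OuterRef(Attr(a.id))))
  println(s"local      = $local")      // List(EqualTo(Attr(d.b), Attr(d.c)))
}
```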
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 74f434e063..61a7d9ea24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -60,6 +60,9 @@ trait CheckAnalysis extends PredicateHelper {
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")
+ case ScalarSubquery(_, conditions, _) if conditions.nonEmpty =>
+ failAnalysis("Correlated scalar subqueries are not supported.")
+
case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
@@ -101,7 +104,6 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis(s"Window specification $s is not valid because $m")
case None => w
}
-
}
operator match {
@@ -111,38 +113,8 @@ trait CheckAnalysis extends PredicateHelper {
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
case f @ Filter(condition, child) =>
- // Make sure that no correlated reference is below Aggregates, Outer Joins and on the
- // right hand side of Unions.
- lazy val outerAttributes = child.outputSet
- def failOnCorrelatedReference(
- plan: LogicalPlan,
- message: String): Unit = plan foreach {
- case p =>
- lazy val inputs = p.inputSet
- p.transformExpressions {
- case e: AttributeReference
- if !inputs.contains(e) && outerAttributes.contains(e) =>
- failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
- }
- }
- def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach {
- case a @ Aggregate(_, _, source) =>
- failOnCorrelatedReference(source, "an AGGREGATE")
- case j @ Join(left, _, RightOuter, _) =>
- failOnCorrelatedReference(left, "a RIGHT OUTER JOIN")
- case j @ Join(_, right, jt, _) if jt != Inner =>
- failOnCorrelatedReference(right, "a LEFT (OUTER) JOIN")
- case Union(_ :: xs) =>
- xs.foreach(failOnCorrelatedReference(_, "a UNION"))
- case s: SetOperation =>
- failOnCorrelatedReference(s.right, "an INTERSECT/EXCEPT")
- case _ =>
- }
splitConjunctivePredicates(condition).foreach {
- case p: PredicateSubquery =>
- checkForCorrelatedReferences(p)
- case Not(p: PredicateSubquery) =>
- checkForCorrelatedReferences(p)
+ case _: PredicateSubquery | Not(_: PredicateSubquery) =>
case e if PredicateSubquery.hasPredicateSubquery(e) =>
failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e")
case e =>
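For reference, a sketch of the queries the updated checks reject (hypothetical tables `t1`/`t2`; the error strings come from the diff above):

```scala
// Fails with "Correlated scalar subqueries are not supported."
sqlContext.sql("select a, (select max(b) from t2 where t2.a = t1.a) from t1")

// Fails with "Predicate sub-queries cannot be used in nested conditions."
sqlContext.sql("select * from t1 where a in (select a from t2) or b > 0")
```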
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 0306afb0d8..5323b79c57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -376,31 +376,6 @@ object HiveTypeCoercion {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
}
-
- case InSubQuery(struct: CreateStruct, subquery, exprId)
- if struct.children.zip(subquery.output).exists(x => x._1.dataType != x._2.dataType) =>
- val widerTypes: Seq[Option[DataType]] = struct.children.zip(subquery.output).map {
- case (l, r) => findWiderTypeForTwo(l.dataType, r.dataType)
- }
- val newStruct = struct.withNewChildren(struct.children.zip(widerTypes).map {
- case (e, Some(t)) => Cast(e, t)
- case (e, _) => e
- })
- val newSubquery = Project(subquery.output.zip(widerTypes).map {
- case (a, Some(t)) => Alias(Cast(a, t), a.toString)()
- case (a, _) => a
- }, subquery)
- InSubQuery(newStruct, newSubquery, exprId)
-
- case sub @ InSubQuery(expr, subquery, exprId)
- if expr.dataType != subquery.output.head.dataType =>
- findWiderTypeForTwo(expr.dataType, subquery.output.head.dataType) match {
- case Some(t) =>
- val attr = subquery.output.head
- val proj = Seq(Alias(Cast(attr, t), attr.toString)())
- InSubQuery(Cast(expr, t), Project(proj, subquery), exprId)
- case _ => sub
- }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 8b38838537..306a99d5a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -337,6 +337,16 @@ case class PrettyAttribute(
override def nullable: Boolean = true
}
+/**
+ * A placeholder for a reference that has been resolved to a field outside of the current
+ * plan. This is used for correlated subqueries.
+ */
+case class OuterReference(e: NamedExpression) extends LeafExpression with Unevaluable {
+ override def dataType: DataType = e.dataType
+ override def nullable: Boolean = e.nullable
+ override def prettyName: String = "outer"
+}
+
object VirtualColumn {
// The attribute name used by Hive, which has different result than Spark, deprecated.
val hiveGroupingIdName: String = "grouping__id"
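A small construction sketch of the new wrapper (assumes `org.apache.spark.sql.catalyst.expressions._` and `org.apache.spark.sql.types.IntegerType` are imported):

```scala
val outerA = AttributeReference("a", IntegerType)()
val wrapped = OuterReference(outerA)
// dataType and nullability delegate to the wrapped attribute:
assert(wrapped.dataType == IntegerType && wrapped.nullable == outerA.nullable)
```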
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 1993bd2587..cd6d3a00b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -26,10 +26,10 @@ import org.apache.spark.sql.types._
* An interface for subquery that is used in expressions.
*/
abstract class SubqueryExpression extends Expression {
+ /** The id of the subquery expression. */
+ def exprId: ExprId
- /**
- * The logical plan of the query.
- */
+ /** The logical plan of the query. */
def query: LogicalPlan
/**
@@ -38,31 +38,30 @@ abstract class SubqueryExpression extends Expression {
*/
def plan: QueryPlan[_]
- /**
- * Updates the query with new logical plan.
- */
+ /** Updates the query with new logical plan. */
def withNewPlan(plan: LogicalPlan): SubqueryExpression
+
+ protected def conditionString: String = children.mkString("[", " && ", "]")
}
/**
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
*
- * Note: `exprId` is used to have unique name in explain string output.
+ * Note: `exprId` is used to have a unique name in explain string output.
*/
case class ScalarSubquery(
query: LogicalPlan,
+ children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {
override def plan: LogicalPlan = SubqueryAlias(toString, query)
- override lazy val resolved: Boolean = query.resolved
+ override lazy val resolved: Boolean = childrenResolved && query.resolved
override def dataType: DataType = query.schema.fields.head.dataType
- override def children: Seq[Expression] = Nil
-
override def checkInputDataTypes(): TypeCheckResult = {
if (query.schema.length != 1) {
TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
@@ -75,9 +74,9 @@ case class ScalarSubquery(
override def foldable: Boolean = false
override def nullable: Boolean = true
- override def withNewPlan(plan: LogicalPlan): ScalarSubquery = ScalarSubquery(plan, exprId)
+ override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
- override def toString: String = s"subquery#${exprId.id}"
+ override def toString: String = s"subquery#${exprId.id} $conditionString"
}
/**
@@ -85,18 +84,34 @@ case class ScalarSubquery(
* [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will
* be rewritten into a left semi/anti join during analysis.
*/
-abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate {
+case class PredicateSubquery(
+ query: LogicalPlan,
+ children: Seq[Expression] = Seq.empty,
+ nullAware: Boolean = false,
+ exprId: ExprId = NamedExpression.newExprId)
+ extends SubqueryExpression with Predicate with Unevaluable {
+ override lazy val resolved = childrenResolved && query.resolved
+ override lazy val references: AttributeSet = super.references -- query.outputSet
override def nullable: Boolean = false
+ override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
+ override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
}
object PredicateSubquery {
def hasPredicateSubquery(e: Expression): Boolean = {
- e.find(_.isInstanceOf[PredicateSubquery]).isDefined
+ e.find {
+ case _: PredicateSubquery | _: ListQuery | _: Exists => true
+ case _ => false
+ }.isDefined
}
}
/**
- * The [[InSubQuery]] predicate checks the existence of a value in a sub-query. For example (SQL):
+ * A [[ListQuery]] expression defines the query that we want to search in an IN subquery
+ * expression. It can only be used in conjunction with an IN expression.
+ *
+ * For example (SQL):
* {{{
* SELECT *
* FROM a
@@ -104,47 +119,15 @@ object PredicateSubquery {
* FROM b)
* }}}
*/
-case class InSubQuery(
- value: Expression,
- query: LogicalPlan,
- exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
- override def children: Seq[Expression] = value :: Nil
- override lazy val resolved: Boolean = value.resolved && query.resolved
- override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan, exprId)
- override def plan: LogicalPlan = SubqueryAlias(s"subquery#${exprId.id}", query)
-
- /**
- * The unwrapped value side expressions.
- */
- lazy val expressions: Seq[Expression] = value match {
- case CreateStruct(cols) => cols
- case col => Seq(col)
- }
-
- /**
- * Check if the number of columns and the data types on both sides match.
- */
- override def checkInputDataTypes(): TypeCheckResult = {
- // Check the number of arguments.
- if (expressions.length != query.output.length) {
- return TypeCheckResult.TypeCheckFailure(
- s"The number of fields in the value (${expressions.length}) does not match with " +
- s"the number of columns in the subquery (${query.output.length})")
- }
-
- // Check the argument types.
- expressions.zip(query.output).zipWithIndex.foreach {
- case ((e, a), i) if e.dataType != a.dataType =>
- return TypeCheckResult.TypeCheckFailure(
- s"The data type of value[$i] (${e.dataType}) does not match " +
- s"subquery column '${a.name}' (${a.dataType}).")
- case _ =>
- }
-
- TypeCheckResult.TypeCheckSuccess
- }
-
- override def toString: String = s"$value IN subquery#${exprId.id}"
+case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
+ extends SubqueryExpression with Unevaluable {
+ override lazy val resolved = false
+ override def children: Seq[Expression] = Seq.empty
+ override def dataType: DataType = ArrayType(NullType)
+ override def nullable: Boolean = false
+ override def withNewPlan(plan: LogicalPlan): ListQuery = copy(query = plan)
+ override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def toString: String = s"list#${exprId.id}"
}
/**
@@ -158,11 +141,12 @@ case class InSubQuery(
* WHERE b.id = a.id)
* }}}
*/
-case class Exists(
- query: LogicalPlan,
- exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
- override def children: Seq[Expression] = Nil
- override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan, exprId)
+case class Exists(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
+ extends SubqueryExpression with Predicate with Unevaluable {
+ override lazy val resolved = false
+ override def children: Seq[Expression] = Seq.empty
+ override def nullable: Boolean = false
+ override def withNewPlan(plan: LogicalPlan): Exists = copy(query = plan)
override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def toString: String = s"exists#${exprId.id}"
}
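Putting the new expression classes together, a minimal construction sketch of the IN-subquery lifecycle (mirroring the test DSL used in AnalysisErrorSuite below; imports from `catalyst.expressions`, `catalyst.plans.logical`, and `sql.types` assumed):

```scala
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// Parsed form of `a IN (select b from ...)`: an unresolved ListQuery inside In.
val parsed = In(a, Seq(ListQuery(LocalRelation(b))))

// Analyzed form produced by ResolveSubquery: a null-aware PredicateSubquery whose
// conditions carry the IN equality plus any pulled-out correlated predicates.
val analyzed = PredicateSubquery(LocalRelation(b), Seq(EqualTo(a, b)), nullAware = true)
```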
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index abbd8facd3..0b70edec8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import scala.collection.immutable.HashSet
-import scala.collection.mutable
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{InSubQuery, _}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
@@ -48,7 +47,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// However, because we also use the analyzer to canonicalized queries (for view definition),
// we do not eliminate subqueries or compute current time in the analyzer.
Batch("Finish Analysis", Once,
- RewritePredicateSubquery,
EliminateSubqueryAliases,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog),
@@ -63,6 +61,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// since the other rules might make two separate Unions operators adjacent.
Batch("Union", Once,
CombineUnions) ::
+ Batch("Subquery", Once,
+ OptimizeSubqueries) ::
Batch("Replace Operators", fixedPoint,
ReplaceIntersectWithSemiJoin,
ReplaceExceptWithAntiJoin,
@@ -99,15 +99,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
- EliminateSerialization) ::
+ EliminateSerialization,
+ RewritePredicateSubquery) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
EmbedSerializerInFilter) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
- Batch("Subquery", Once,
- OptimizeSubqueries) ::
Batch("OptimizeCodegen", Once,
OptimizeCodegen(conf)) :: Nil
}
@@ -117,8 +116,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
*/
object OptimizeSubqueries extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case subquery: SubqueryExpression =>
- subquery.withNewPlan(Optimizer.this.execute(subquery.query))
+ case s: SubqueryExpression =>
+ s.withNewPlan(Optimizer.this.execute(s.query))
}
}
}
@@ -636,7 +635,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
// Only consider constraints that can be pushed down completely to either the left or the
// right child
val constraints = join.constraints.filter { c =>
- c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)}
+ c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
+ }
// Remove those constraints that are already enforced by either the left or the right child
val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
val newConditionOpt = conditionOpt match {
@@ -1123,7 +1123,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper {
* Returns whether the expression returns null or false when all inputs are nulls.
*/
private def canFilterOutNull(e: Expression): Boolean = {
- if (!e.deterministic) return false
+ if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false
val attributes = e.references.toSeq
val emptyRow = new GenericInternalRow(attributes.length)
val v = BindReferences.bindReference(e, attributes).eval(emptyRow)
@@ -1503,94 +1503,6 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
- /**
- * Pull out all correlated predicates from a given sub-query. This method removes the correlated
- * predicates from sub-query [[Filter]]s and adds the references of these predicates to
- * all intermediate [[Project]] clauses (if they are missing) in order to be able to evaluate the
- * predicates in the join condition.
- *
- * This method returns the rewritten sub-query and the combined (AND) extracted predicate.
- */
- private def pullOutCorrelatedPredicates(
- subquery: LogicalPlan,
- query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
- val references = query.outputSet
- val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]]
- val transformed = subquery transformUp {
- case f @ Filter(cond, child) =>
- // Find all correlated predicates.
- val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
- (e.references -- child.outputSet).intersect(references).nonEmpty
- }
- // Rewrite the filter without the correlated predicates if any.
- correlated match {
- case Nil => f
- case xs if local.nonEmpty =>
- val newFilter = Filter(local.reduce(And), child)
- predicateMap += newFilter -> correlated
- newFilter
- case xs =>
- predicateMap += child -> correlated
- child
- }
- case p @ Project(expressions, child) =>
- // Find all pulled out predicates defined in the Project's subtree.
- val localPredicates = p.collect(predicateMap).flatten
-
- // Determine which correlated predicate references are missing from this project.
- val localPredicateReferences = localPredicates
- .map(_.references)
- .reduceOption(_ ++ _)
- .getOrElse(AttributeSet.empty)
- val missingReferences = localPredicateReferences -- p.references -- query.outputSet
-
- // Create a new project if we need to add missing references.
- if (missingReferences.nonEmpty) {
- Project(expressions ++ missingReferences, child)
- } else {
- p
- }
- }
- (transformed, predicateMap.values.flatten.toSeq)
- }
-
- /**
- * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by
- * constructing the required join condition. Both the rewritten subquery and the constructed
- * join condition are returned.
- */
- private def pullOutCorrelatedPredicates(
- in: InSubQuery,
- query: LogicalPlan): (LogicalPlan, LogicalPlan, Seq[Expression]) = {
- val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
- // Check whether there is some attributes have same exprId but come from different side
- val outerAttributes = AttributeSet(in.expressions.flatMap(_.references))
- if (outerAttributes.intersect(resolved.outputSet).nonEmpty) {
- val aliases = mutable.Map[Attribute, Alias]()
- val exprs = in.expressions.map { expr =>
- expr transformUp {
- case a: AttributeReference if resolved.outputSet.contains(a) =>
- val alias = Alias(a, a.toString)()
- val attr = alias.toAttribute
- aliases += attr -> alias
- attr
- }
- }
- val newP = Project(query.output ++ aliases.values, query)
- val projection = resolved.output.map {
- case a if outerAttributes.contains(a) => Alias(a, a.toString)()
- case a => a
- }
- val subquery = Project(projection, resolved)
- val conditions = joinCondition ++ exprs.zip(subquery.output).map(EqualTo.tupled)
- (newP, subquery, conditions)
- } else {
- val conditions =
- joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
- (query, resolved, conditions)
- }
- }
-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f @ Filter(condition, child) =>
val (withSubquery, withoutSubquery) =
@@ -1604,22 +1516,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
- case (p, Exists(sub, _)) =>
- val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
- Join(p, resolved, LeftSemi, conditions.reduceOption(And))
- case (p, Not(Exists(sub, _))) =>
- val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
- Join(p, resolved, LeftAnti, conditions.reduceOption(And))
- case (p, in: InSubQuery) =>
- val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
- if (newP fastEquals p) {
- Join(p, resolved, LeftSemi, conditions.reduceOption(And))
- } else {
- Project(p.output,
- Join(newP, resolved, LeftSemi, conditions.reduceOption(And)))
- }
- case (p, Not(in: InSubQuery)) =>
- val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+ case (p, PredicateSubquery(sub, conditions, _, _)) =>
+ Join(p, sub, LeftSemi, conditions.reduceOption(And))
+ case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
+ Join(p, sub, LeftAnti, conditions.reduceOption(And))
+ case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
// This is a NULL-aware (left) anti join (NAAJ).
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
@@ -1628,12 +1529,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS
// if performance matters to you.
- if (newP fastEquals p) {
- Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
- } else {
- Project(p.output,
- Join(newP, resolved, LeftAnti, Option(Or(anyNull, condition))))
- }
+ Join(p, sub, LeftAnti, Option(Or(anyNull, condition)))
}
}
}
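In summary, the simplified `RewritePredicateSubquery` now maps each predicate-subquery shape directly onto a join. A shape sketch (placeholders, not runnable as-is):

```scala
// Filter(PredicateSubquery(sub, conditions), child)
//   => Join(child, sub, LeftSemi, conditions.reduceOption(And))
// Filter(Not(PredicateSubquery(sub, conditions, nullAware = false)), child)
//   => Join(child, sub, LeftAnti, conditions.reduceOption(And))
// Filter(Not(PredicateSubquery(sub, conditions, nullAware = true)), child)
//   => Join(child, sub, LeftAnti, Option(Or(anyNull, condition)))  // NULL-aware anti join
```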
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 7f98c21af2..1f923f47dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -956,7 +956,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
GreaterThanOrEqual(e, expression(ctx.lower)),
LessThanOrEqual(e, expression(ctx.upper))))
case SqlBaseParser.IN if ctx.query != null =>
- invertIfNotDefined(InSubQuery(e, plan(ctx.query)))
+ invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query)))))
case SqlBaseParser.IN =>
invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
case SqlBaseParser.LIKE =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index b358e210da..b2297bbcaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -107,8 +107,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
override def maxRows: Option[Long] = child.maxRows
- override protected def validConstraints: Set[Expression] =
- child.constraints.union(splitConjunctivePredicates(condition).toSet)
+ override protected def validConstraints: Set[Expression] = {
+ val predicates = splitConjunctivePredicates(condition)
+ .filterNot(PredicateSubquery.hasPredicateSubquery)
+ child.constraints.union(predicates.toSet)
+ }
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
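A shape sketch of the effect of the `validConstraints` change above (placeholder expressions, not runnable): conjuncts containing a predicate subquery are no longer treated as row-level constraints, because they are later rewritten into joins by `RewritePredicateSubquery` and cannot be evaluated per row.

```scala
// Before: Filter('a > 1 && PredicateSubquery(...), child).constraints
//         included the subquery conjunct.
// After:  only the plain conjunct survives:
//         Filter('a > 1 && PredicateSubquery(...), child).constraints
//           == child.constraints + ('a > 1)
```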
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index a90636d278..1b08913ddd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
@@ -449,7 +450,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val plan = Project(
- Seq(a, Alias(InSubQuery(a, LocalRelation(b)), "c")()),
+ Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()),
LocalRelation(a))
assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil)
}
@@ -458,10 +459,10 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", BooleanType)()
- val plan1 = Filter(Cast(InSubQuery(a, LocalRelation(b)), BooleanType), LocalRelation(a))
+ val plan1 = Filter(Cast(In(a, Seq(ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a))
assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
- val plan2 = Filter(Or(InSubQuery(a, LocalRelation(b)), c), LocalRelation(a, c))
+ val plan2 = Filter(Or(In(a, Seq(ListQuery(LocalRelation(b)))), c), LocalRelation(a, c))
assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
}
@@ -474,7 +475,7 @@ class AnalysisErrorSuite extends AnalysisTest {
Exists(
Join(
LocalRelation(b),
- Filter(EqualTo(a, c), LocalRelation(c)),
+ Filter(EqualTo(OuterReference(a), c), LocalRelation(c)),
LeftOuter,
Option(EqualTo(b, c)))),
LocalRelation(a))
@@ -483,7 +484,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan2 = Filter(
Exists(
Join(
- Filter(EqualTo(a, c), LocalRelation(c)),
+ Filter(EqualTo(OuterReference(a), c), LocalRelation(c)),
LocalRelation(b),
RightOuter,
Option(EqualTo(b, c)))),
@@ -491,13 +492,16 @@ class AnalysisErrorSuite extends AnalysisTest {
assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil)
val plan3 = Filter(
- Exists(Aggregate(Seq.empty, Seq.empty, Filter(EqualTo(a, c), LocalRelation(c)))),
+ Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))),
LocalRelation(a))
assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
+ }
- val plan4 = Filter(
- Exists(Union(LocalRelation(b), Filter(EqualTo(a, c), LocalRelation(c)))),
- LocalRelation(a))
- assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
+ test("Correlated Scalar Subquery") {
+ val a = AttributeReference("a", IntegerType)()
+ val b = AttributeReference("b", IntegerType)()
+ val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
+ val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a))
+ assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index e9b4bb002b..fcc14a803b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, PlanTest, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.IntegerType
@@ -725,6 +725,43 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctedAnswer)
}
+ test("predicate subquery: push down simple") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z)
+
+ val query = x
+ .join(y, Inner, Option("x.a".attr === "y.a".attr))
+ .where(Exists(z.where("x.a".attr === "z.a".attr)))
+ .analyze
+ val answer = x
+ .where(Exists(z.where("x.a".attr === "z.a".attr)))
+ .join(y, Inner, Option("x.a".attr === "y.a".attr))
+ .analyze
+ val optimized = Optimize.execute(Optimize.execute(query))
+ comparePlans(optimized, answer)
+ }
+
+ test("predicate subquery: push down complex") {
+ val w = testRelation.subquery('w)
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z)
+
+ val query = w
+ .join(x, Inner, Option("w.a".attr === "x.a".attr))
+ .join(y, LeftOuter, Option("x.a".attr === "y.a".attr))
+ .where(Exists(z.where("w.a".attr === "z.a".attr)))
+ .analyze
+ val answer = w
+ .where(Exists(z.where("w.a".attr === "z.a".attr)))
+ .join(x, Inner, Option("w.a".attr === "x.a".attr))
+ .join(y, LeftOuter, Option("x.a".attr === "y.a".attr))
+ .analyze
+ val optimized = Optimize.execute(Optimize.execute(query))
+ comparePlans(optimized, answer)
+ }
+
test("Window: predicate push down -- basic") {
val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 5af3ea9c7a..e73592c7af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -146,7 +146,7 @@ class ExpressionParserSuite extends PlanTest {
test("in sub-query") {
assertEqual(
"a in (select b from c)",
- InSubQuery('a, table("c").select('b)))
+ In('a, Seq(ListQuery(table("c").select('b)))))
}
test("like expressions") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index f5439d70ad..6310f0c2bc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -34,11 +34,13 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
protected def normalizeExprIds(plan: LogicalPlan) = {
plan transformAllExpressions {
case s: ScalarSubquery =>
- ScalarSubquery(s.query, ExprId(0))
- case s: InSubQuery =>
- InSubQuery(s.value, s.query, ExprId(0))
+ s.copy(exprId = ExprId(0))
case e: Exists =>
- Exists(e.query, ExprId(0))
+ e.copy(exprId = ExprId(0))
+ case l: ListQuery =>
+ l.copy(exprId = ExprId(0))
+ case p: PredicateSubquery =>
+ p.copy(exprId = ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>