aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNattavut Sutyanyong <nsy.can@gmail.com>2017-03-14 10:37:10 +0100
committerHerman van Hovell <hvanhovell@databricks.com>2017-03-14 10:37:10 +0100
commit4ce970d71488c7de6025ef925f75b8b92a5a6a79 (patch)
tree2857e3a5c373359042796ca662769b786bc83fbf
parentf6314eab4b494bd5b5e9e41c6f582d4f22c0967a (diff)
downloadspark-4ce970d71488c7de6025ef925f75b8b92a5a6a79.tar.gz
spark-4ce970d71488c7de6025ef925f75b8b92a5a6a79.tar.bz2
spark-4ce970d71488c7de6025ef925f75b8b92a5a6a79.zip
[SPARK-18874][SQL] First phase: Deferring the correlated predicate pull up to Optimizer phase
## What changes were proposed in this pull request? Currently Analyzer as part of ResolveSubquery, pulls up the correlated predicates to its originating SubqueryExpression. The subquery plan is then transformed to remove the correlated predicates after they are moved up to the outer plan. In this PR, the task of pulling up correlated predicates is deferred to Optimizer. This is the initial work that will allow us to support the form of correlated subqueries that we don't support today. The design document from nsyca can be found in the following link : [DesignDoc](https://docs.google.com/document/d/1QDZ8JwU63RwGFS6KVF54Rjj9ZJyK33d49ZWbjFBaIgU/edit#) The brief description of code changes (hopefully to aid with code review) can be be found in the following link: [CodeChanges](https://docs.google.com/document/d/18mqjhL9V1An-tNta7aVE13HkALRZ5GZ24AATA-Vqqf0/edit#) ## How was this patch tested? The test case PRs were submitted earlier using. [16337](https://github.com/apache/spark/pull/16337) [16759](https://github.com/apache/spark/pull/16759) [16841](https://github.com/apache/spark/pull/16841) [16915](https://github.com/apache/spark/pull/16915) [16798](https://github.com/apache/spark/pull/16798) [16712](https://github.com/apache/spark/pull/16712) [16710](https://github.com/apache/spark/pull/16710) [16760](https://github.com/apache/spark/pull/16760) [16802](https://github.com/apache/spark/pull/16802) Author: Dilip Biswal <dbiswal@us.ibm.com> Closes #16954 from dilipbiswal/SPARK-18874.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala314
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala40
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala130
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala43
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala256
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala159
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala11
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala3
-rw-r--r--sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala7
13 files changed, 675 insertions, 300 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 93666f1495..a3764d8c84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -21,12 +21,13 @@ import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf, TableIdentifier}
+import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
@@ -162,6 +163,8 @@ class Analyzer(
FixNullability),
Batch("ResolveTimeZone", Once,
ResolveTimeZone),
+ Batch("Subquery", Once,
+ UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)
@@ -710,13 +713,72 @@ class Analyzer(
} transformUp {
case other => other transformExpressions {
case a: Attribute =>
- attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
+ dedupAttr(a, attributeRewrites)
+ case s: SubqueryExpression =>
+ s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
}
}
newRight
}
}
+ private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+ attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier)
+ }
+
+ /**
+ * The outer plan may have been de-duplicated and the function below updates the
+ * outer references to refer to the de-duplicated attributes.
+ *
+ * For example (SQL):
+ * {{{
+ * SELECT * FROM t1
+ * INTERSECT
+ * SELECT * FROM t1
+ * WHERE EXISTS (SELECT 1
+ * FROM t2
+ * WHERE t1.c1 = t2.c1)
+ * }}}
+ * Plan before resolveReference rule.
+ * 'Intersect
+ * :- Project [c1#245, c2#246]
+ * : +- SubqueryAlias t1
+ * : +- Relation[c1#245,c2#246] parquet
+ * +- 'Project [*]
+ * +- Filter exists#257 [c1#245]
+ * : +- Project [1 AS 1#258]
+ * : +- Filter (outer(c1#245) = c1#251)
+ * : +- SubqueryAlias t2
+ * : +- Relation[c1#251,c2#252] parquet
+ * +- SubqueryAlias t1
+ * +- Relation[c1#245,c2#246] parquet
+ * Plan after the resolveReference rule.
+ * Intersect
+ * :- Project [c1#245, c2#246]
+ * : +- SubqueryAlias t1
+ * : +- Relation[c1#245,c2#246] parquet
+ * +- Project [c1#259, c2#260]
+ * +- Filter exists#257 [c1#259]
+ * : +- Project [1 AS 1#258]
+ * : +- Filter (outer(c1#259) = c1#251) => Updated
+ * : +- SubqueryAlias t2
+ * : +- Relation[c1#251,c2#252] parquet
+ * +- SubqueryAlias t1
+ * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
+ */
+ private def dedupOuterReferencesInSubquery(
+ plan: LogicalPlan,
+ attrMap: AttributeMap[Attribute]): LogicalPlan = {
+ plan transformDown { case currentFragment =>
+ currentFragment transformExpressions {
+ case OuterReference(a: Attribute) =>
+ OuterReference(dedupAttr(a, attrMap))
+ case s: SubqueryExpression =>
+ s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
+ }
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
@@ -1132,28 +1194,21 @@ class Analyzer(
}
/**
- * Pull out all (outer) correlated predicates from a given subquery. This method removes the
- * correlated predicates from subquery [[Filter]]s and adds the references of these predicates
- * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to
- * be able to evaluate the predicates at the top level.
- *
- * This method returns the rewritten subquery and correlated predicates.
+ * Validates to make sure the outer references appearing inside the subquery
+ * are legal. This function also returns the list of expressions
+ * that contain outer references. These outer references would be kept as children
+ * of subquery expressions by the caller of this function.
*/
- private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
- val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
+ private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = {
+ val outerReferences = ArrayBuffer.empty[Expression]
// Make sure a plan's subtree does not contain outer references
def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
- if (p.collectFirst(predicateMap).nonEmpty) {
+ if (hasOuterReferences(p)) {
failAnalysis(s"Accessing outer query column is not allowed in:\n$p")
}
}
- // Helper function for locating outer references.
- def containsOuter(e: Expression): Boolean = {
- e.find(_.isInstanceOf[OuterReference]).isDefined
- }
-
// Make sure a plan's expressions do not contain outer references
def failOnOuterReference(p: LogicalPlan): Unit = {
if (p.expressions.exists(containsOuter)) {
@@ -1194,20 +1249,11 @@ class Analyzer(
}
}
- /** Determine which correlated predicate references are missing from this plan. */
- def missingReferences(p: LogicalPlan): AttributeSet = {
- val localPredicateReferences = p.collect(predicateMap)
- .flatten
- .map(_.references)
- .reduceOption(_ ++ _)
- .getOrElse(AttributeSet.empty)
- localPredicateReferences -- p.outputSet
- }
-
var foundNonEqualCorrelatedPred : Boolean = false
- // Simplify the predicates before pulling them out.
- val transformed = BooleanSimplification(sub) transformUp {
+ // Simplify the predicates before validating any unsupported correlation patterns
+ // in the plan.
+ BooleanSimplification(sub).foreachUp {
// Whitelist operators allowed in a correlated subquery
// There are 4 categories:
@@ -1229,80 +1275,48 @@ class Analyzer(
// Category 1:
// BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias
- case p: BroadcastHint =>
- p
- case p: Distinct =>
- p
- case p: LeafNode =>
- p
- case p: Repartition =>
- p
- case p: SubqueryAlias =>
- p
+ case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>
// Category 2:
// These operators can be anywhere in a correlated subquery.
// so long as they do not host outer references in the operators.
- case p: Sort =>
- failOnOuterReference(p)
- p
- case p: RepartitionByExpression =>
- failOnOuterReference(p)
- p
+ case s: Sort =>
+ failOnOuterReference(s)
+ case r: RepartitionByExpression =>
+ failOnOuterReference(r)
// Category 3:
// Filter is one of the two operators allowed to host correlated expressions.
// The other operator is Join. Filter can be anywhere in a correlated subquery.
- case f @ Filter(cond, child) =>
+ case f: Filter =>
// Find all predicates with an outer reference.
- val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)
+ val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)
// Find any non-equality correlated predicates
foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
case _: EqualTo | _: EqualNullSafe => false
case _ => true
}
-
- // Rewrite the filter without the correlated predicates if any.
- correlated match {
- case Nil => f
- case xs if local.nonEmpty =>
- val newFilter = Filter(local.reduce(And), child)
- predicateMap += newFilter -> xs
- newFilter
- case xs =>
- predicateMap += child -> xs
- child
- }
+ // The aggregate expressions are treated in a special way by getOuterReferences. If the
+ // aggregate expression contains only outer reference attributes then the entire aggregate
+ // expression is isolated as an OuterReference.
+ // i.e min(OuterReference(b)) => OuterReference(min(b))
+ outerReferences ++= getOuterReferences(correlated)
// Project cannot host any correlated expressions
// but can be anywhere in a correlated subquery.
- case p @ Project(expressions, child) =>
+ case p: Project =>
failOnOuterReference(p)
- val referencesToAdd = missingReferences(p)
- if (referencesToAdd.nonEmpty) {
- Project(expressions ++ referencesToAdd, child)
- } else {
- p
- }
-
// Aggregate cannot host any correlated expressions
// It can be on a correlation path if the correlation contains
// only equality correlated predicates.
// It cannot be on a correlation path if the correlation has
// non-equality correlated predicates.
- case a @ Aggregate(grouping, expressions, child) =>
+ case a: Aggregate =>
failOnOuterReference(a)
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
- val referencesToAdd = missingReferences(a)
- if (referencesToAdd.nonEmpty) {
- Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
- } else {
- a
- }
-
// Join can host correlated expressions.
case j @ Join(left, right, joinType, _) =>
joinType match {
@@ -1332,7 +1346,6 @@ class Analyzer(
case _ =>
failOnOuterReferenceInSubTree(j)
}
- j
// Generator with join=true, i.e., expressed with
// LATERAL VIEW [OUTER], similar to inner join,
@@ -1340,9 +1353,8 @@ class Analyzer(
// but must not host any outer references.
// Note:
// Generator with join=false is treated as Category 4.
- case p @ Generate(generator, true, _, _, _, _) =>
- failOnOuterReference(p)
- p
+ case g: Generate if g.join =>
+ failOnOuterReference(g)
// Category 4: Any other operators not in the above 3 categories
// cannot be on a correlation path, that is they are allowed only
@@ -1350,54 +1362,17 @@ class Analyzer(
// are not allowed to have any correlated expressions.
case p =>
failOnOuterReferenceInSubTree(p)
- p
}
- (transformed, predicateMap.values.flatten.toSeq)
+ outerReferences
}
/**
- * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same
- * attributes.
- */
- private def rewriteSubQuery(
- sub: LogicalPlan,
- outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
- // Pull out the tagged predicates and rewrite the subquery in the process.
- val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub)
-
- // Make sure the inner and the outer query attributes do not collide.
- val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
- val duplicates = basePlan.outputSet.intersect(outputSet)
- val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
- val aliasMap = AttributeMap(duplicates.map { dup =>
- dup -> Alias(dup, dup.toString)()
- }.toSeq)
- val aliasedExpressions = basePlan.output.map { ref =>
- aliasMap.getOrElse(ref, ref)
- }
- val aliasedProjection = Project(aliasedExpressions, basePlan)
- val aliasedConditions = baseConditions.map(_.transform {
- case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
- })
- (aliasedProjection, aliasedConditions)
- } else {
- (basePlan, baseConditions)
- }
- // Remove outer references from the correlated predicates. We wait with extracting
- // these until collisions between the inner and outer query attributes have been
- // solved.
- val conditions = deDuplicatedConditions.map(_.transform {
- case OuterReference(ref) => ref
- })
- (plan, conditions)
- }
-
- /**
- * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method
+ * Resolves the subquery. The subquery is resolved using its outer plans. This method
* will resolve the subquery by alternating between the regular analyzer and by applying the
* resolveOuterReferences rule.
*
- * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved.
+ * Outer references from the correlated predicates are updated as children of
+ * Subquery expression.
*/
private def resolveSubQuery(
e: SubqueryExpression,
@@ -1420,7 +1395,8 @@ class Analyzer(
}
} while (!current.resolved && !current.fastEquals(previous))
- // Step 2: Pull out the predicates if the plan is resolved.
+ // Step 2: If the subquery plan is fully resolved, pull the outer references and record
+ // them as children of SubqueryExpression.
if (current.resolved) {
// Make sure the resolved query has the required number of output columns. This is only
// needed for Scalar and IN subqueries.
@@ -1428,34 +1404,37 @@ class Analyzer(
failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
s"does not match the required number of columns ($requiredColumns)")
}
- // Pullout predicates and construct a new plan.
- f.tupled(rewriteSubQuery(current, plans))
+ // Validate the outer reference and record the outer references as children of
+ // subquery expression.
+ f(current, checkAndGetOuterReferences(current))
} else {
e.withNewPlan(current)
}
}
/**
- * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS
- * expressions into PredicateSubquery expression once the are resolved.
+ * Resolves the subquery. Apart of resolving the subquery and outer references (if any)
+ * in the subquery plan, the children of subquery expression are updated to record the
+ * outer references. This is needed to make sure
+ * (1) The column(s) referred from the outer query are not pruned from the plan during
+ * optimization.
+ * (2) Any aggregate expression(s) that reference outer attributes are pushed down to
+ * outer plan to get evaluated.
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan transformExpressions {
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
- case e @ Exists(sub, exprId) =>
- resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
- case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
+ case e @ Exists(sub, _, exprId) if !sub.resolved =>
+ resolveSubQuery(e, plans)(Exists(_, _, exprId))
+ case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
// Get the left hand side expressions.
- val expressions = e match {
+ val expressions = value match {
case cns : CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
- resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
- // Construct the IN conditions.
- val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled)
- PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId)
- }
+ val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId))
+ In(value, Seq(expr))
}
}
@@ -2353,6 +2332,11 @@ class Analyzer(
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
e.withTimeZone(conf.sessionLocalTimeZone)
+ // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
+ // the types between the value expression and list query expression of IN expression.
+ // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
+ // information for time zone aware expressions.
+ case e: ListQuery => e.withNewPlan(apply(e.plan))
}
}
}
@@ -2533,3 +2517,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
CreateNamedStruct(children.toList)
}
}
+
+/**
+ * The aggregate expressions from subquery referencing outer query block are pushed
+ * down to the outer query block for evaluation. This rule below updates such outer references
+ * as AttributeReference referring attributes from the parent/outer query block.
+ *
+ * For example (SQL):
+ * {{{
+ * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b))
+ * }}}
+ * Plan before the rule.
+ * Project [a#226]
+ * +- Filter exists#245 [min(b#227)#249]
+ * : +- Project [1 AS 1#247]
+ * : +- Filter (d#238 < min(outer(b#227))) <-----
+ * : +- SubqueryAlias r
+ * : +- Project [_1#234 AS c#237, _2#235 AS d#238]
+ * : +- LocalRelation [_1#234, _2#235]
+ * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
+ * +- SubqueryAlias l
+ * +- Project [_1#223 AS a#226, _2#224 AS b#227]
+ * +- LocalRelation [_1#223, _2#224]
+ * Plan after the rule.
+ * Project [a#226]
+ * +- Filter exists#245 [min(b#227)#249]
+ * : +- Project [1 AS 1#247]
+ * : +- Filter (d#238 < outer(min(b#227)#249)) <-----
+ * : +- SubqueryAlias r
+ * : +- Project [_1#234 AS c#237, _2#235 AS d#238]
+ * : +- LocalRelation [_1#234, _2#235]
+ * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
+ * +- SubqueryAlias l
+ * +- Project [_1#223 AS a#226, _2#224 AS b#227]
+ * +- LocalRelation [_1#223, _2#224]
+ */
+object UpdateOuterReferences extends Rule[LogicalPlan] {
+ private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child }
+
+ private def updateOuterReferenceInSubquery(
+ plan: LogicalPlan,
+ refExprs: Seq[Expression]): LogicalPlan = {
+ plan transformAllExpressions { case e =>
+ val outerAlias =
+ refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e)))
+ outerAlias match {
+ case Some(a: Alias) => OuterReference(a.toAttribute)
+ case _ => e
+ }
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan transform {
+ case f @ Filter(_, a: Aggregate) if f.resolved =>
+ f transformExpressions {
+ case s: SubqueryExpression if s.children.nonEmpty =>
+ // Collect the aliases from output of aggregate.
+ val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
+ // Update the subquery plan to record the OuterReference to point to outer query plan.
+ s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases))
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index d32fbeb4e9..da0c6b098f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper {
if (conditions.isEmpty && query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
- } else if (conditions.nonEmpty) {
- // Collect the columns from the subquery for further checking.
- var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
-
+ }
+ else if (conditions.nonEmpty) {
def checkAggregate(agg: Aggregate): Unit = {
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates containing exactly one aggregate expression.
@@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper {
// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
+ // Collect the local references from the correlated predicate in the subquery.
+ val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
+ .filterNot(conditions.flatMap(_.references).contains)
val correlatedCols = AttributeSet(subqueryColumns)
val invalidCols = groupByCols -- correlatedCols
// GROUP BY columns must be a subset of columns in the predicates
@@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper {
// For projects, do the necessary mapping and skip to its child.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQuery(s.child)
- case p: Project =>
- // SPARK-18814: Map any aliases to their AttributeReference children
- // for the checking in the Aggregate operators below this Project.
- subqueryColumns = subqueryColumns.map {
- xs => p.projectList.collectFirst {
- case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
- child
- }.getOrElse(xs)
- }
-
- cleanQuery(p.child)
+ case p: Project => cleanQuery(p.child)
case child => child
}
@@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper {
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
- case Filter(condition, _) =>
- splitConjunctivePredicates(condition).foreach {
- case _: PredicateSubquery | Not(_: PredicateSubquery) =>
- case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
- failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" +
- s" conditions: $e")
- case e =>
- }
+ case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) =>
+ failAnalysis("Null-aware predicate sub-queries cannot be used in nested " +
+ s"conditions: $condition")
case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
@@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper {
s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
}
- case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
- failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+ case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) =>
+ p match {
+ case _: Filter => // Ok
+ case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+ }
case _: Union | _: SetOperation if operator.children.length > 1 =>
def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 2c00957bd6..768897dc07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -109,6 +109,28 @@ object TypeCoercion {
}
/**
+ * This function determines the target type of a comparison operator when one operand
+ * is a String and the other is not. It also handles when one op is a Date and the
+ * other is a Timestamp by making the target type to be String.
+ */
+ val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
+ // We should cast all relative timestamp/date/string comparison into string comparisons
+ // This behaves as a user would expect because timestamp strings sort lexicographically.
+ // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
+ case (StringType, DateType) => Some(StringType)
+ case (DateType, StringType) => Some(StringType)
+ case (StringType, TimestampType) => Some(StringType)
+ case (TimestampType, StringType) => Some(StringType)
+ case (TimestampType, DateType) => Some(StringType)
+ case (DateType, TimestampType) => Some(StringType)
+ case (StringType, NullType) => Some(StringType)
+ case (NullType, StringType) => Some(StringType)
+ case (l: StringType, r: AtomicType) if r != StringType => Some(r)
+ case (l: AtomicType, r: StringType) if (l != StringType) => Some(l)
+ case (l, r) => None
+ }
+
+ /**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
* i.e. the main difference with [[findTightestCommonType]] is that here we allow some
@@ -305,6 +327,14 @@ object TypeCoercion {
* Promotes strings that appear in arithmetic expressions.
*/
object PromoteStrings extends Rule[LogicalPlan] {
+ private def castExpr(expr: Expression, targetType: DataType): Expression = {
+ (expr.dataType, targetType) match {
+ case (NullType, dt) => Literal.create(null, targetType)
+ case (l, dt) if (l != dt) => Cast(expr, targetType)
+ case _ => expr
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -321,37 +351,10 @@ object TypeCoercion {
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))
- // We should cast all relative timestamp/date/string comparison into string comparisons
- // This behaves as a user would expect because timestamp strings sort lexicographically.
- // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
- case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
- p.makeCopy(Array(left, Cast(right, StringType)))
- case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
- p.makeCopy(Array(Cast(left, StringType), right))
- case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
- p.makeCopy(Array(left, Cast(right, StringType)))
- case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
- p.makeCopy(Array(Cast(left, StringType), right))
-
- // Comparisons between dates and timestamps.
- case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
- p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
- case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
- p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
-
- // Checking NullType
- case p @ BinaryComparison(left @ StringType(), right @ NullType()) =>
- p.makeCopy(Array(left, Literal.create(null, StringType)))
- case p @ BinaryComparison(left @ NullType(), right @ StringType()) =>
- p.makeCopy(Array(Literal.create(null, StringType), right))
-
- // When compare string with atomic type, case string to that type.
- case p @ BinaryComparison(left @ StringType(), right @ AtomicType())
- if right.dataType != StringType =>
- p.makeCopy(Array(Cast(left, right.dataType), right))
- case p @ BinaryComparison(left @ AtomicType(), right @ StringType())
- if left.dataType != StringType =>
- p.makeCopy(Array(left, Cast(right, left.dataType)))
+ case p @ BinaryComparison(left, right)
+ if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
+ val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
+ p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
@@ -365,17 +368,72 @@ object TypeCoercion {
}
/**
- * Convert the value and in list expressions to the common operator type
- * by looking at all the argument types and finding the closest one that
- * all the arguments can be cast to. When no common operator type is found
- * the original expression will be returned and an Analysis Exception will
- * be raised at type checking phase.
+ * Handles type coercion for both IN expression with subquery and IN
+ * expressions without subquery.
+ * 1. In the first case, find the common type by comparing the left hand side (LHS)
+ * expression types against corresponding right hand side (RHS) expression derived
+ * from the subquery expression's plan output. Inject appropriate casts in the
+ * LHS and RHS side of IN expression.
+ *
+ * 2. In the second case, convert the value and in list expressions to the
+ * common operator type by looking at all the argument types and finding
+ * the closest one that all the arguments can be cast to. When no common
+ * operator type is found the original expression will be returned and an
+ * Analysis Exception will be raised at the type checking phase.
*/
object InConversion extends Rule[LogicalPlan] {
+ private def flattenExpr(expr: Expression): Seq[Expression] = {
+ expr match {
+ // Multi columns in IN clause is represented as a CreateNamedStruct.
+ // flatten the named struct to get the list of expressions.
+ case cns: CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
+ // Handle type casting required between value expression and subquery output
+ // in IN subquery.
+ case i @ In(a, Seq(ListQuery(sub, children, exprId)))
+ if !i.resolved && flattenExpr(a).length == sub.output.length =>
+ // LHS is the value expression of IN subquery.
+ val lhs = flattenExpr(a)
+
+ // RHS is the subquery output.
+ val rhs = sub.output
+
+ val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
+ findCommonTypeForBinaryComparison(l.dataType, r.dataType)
+ .orElse(findTightestCommonType(l.dataType, r.dataType))
+ }
+
+ // The number of columns/expressions must match between LHS and RHS of an
+ // IN subquery expression.
+ if (commonTypes.length == lhs.length) {
+ val castedRhs = rhs.zip(commonTypes).map {
+ case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+ case (e, _) => e
+ }
+ val castedLhs = lhs.zip(commonTypes).map {
+ case (e, dt) if e.dataType != dt => Cast(e, dt)
+ case (e, _) => e
+ }
+
+ // Before constructing the In expression, wrap the multi values in LHS
+ // in a CreatedNamedStruct.
+ val newLhs = castedLhs match {
+ case Seq(lhs) => lhs
+ case _ => CreateStruct(castedLhs)
+ }
+
+ In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
+ } else {
+ i
+ }
+
case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ac56ff13fa..e5d1a1e299 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -123,19 +123,44 @@ case class Not(child: Expression)
*/
@ExpressionDescription(
usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.")
-case class In(value: Expression, list: Seq[Expression]) extends Predicate
- with ImplicitCastInputTypes {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate {
require(list != null, "list should not be null")
+ override def checkInputDataTypes(): TypeCheckResult = {
+ list match {
+ case ListQuery(sub, _, _) :: Nil =>
+ val valExprs = value match {
+ case cns: CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
+ }
- override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
+ val mismatchedColumns = valExprs.zip(sub.output).flatMap {
+ case (l, r) if l.dataType != r.dataType =>
+ s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
+ case _ => None
+ }
- override def checkInputDataTypes(): TypeCheckResult = {
- if (list.exists(l => l.dataType != value.dataType)) {
- TypeCheckResult.TypeCheckFailure(
- "Arguments must be same type")
- } else {
- TypeCheckResult.TypeCheckSuccess
+ if (mismatchedColumns.nonEmpty) {
+ TypeCheckResult.TypeCheckFailure(
+ s"""
+ |The data type of one or more elements in the left hand side of an IN subquery
+ |is not compatible with the data type of the output of the subquery
+ |Mismatched columns:
+ |[${mismatchedColumns.mkString(", ")}]
+ |Left side:
+ |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
+ |Right side:
+ |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
+ """.stripMargin)
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ case _ =>
+ if (list.exists(l => l.dataType != value.dataType)) {
+ TypeCheckResult.TypeCheckFailure("Arguments must be same type")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index e2e7d98e33..ad11700fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.types._
/**
@@ -40,19 +43,184 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
/**
* A base interface for expressions that contain a [[LogicalPlan]].
*/
-abstract class SubqueryExpression extends PlanExpression[LogicalPlan] {
+abstract class SubqueryExpression(
+ plan: LogicalPlan,
+ children: Seq[Expression],
+ exprId: ExprId) extends PlanExpression[LogicalPlan] {
+
+ override lazy val resolved: Boolean = childrenResolved && plan.resolved
+ override lazy val references: AttributeSet =
+ if (plan.resolved) super.references -- plan.outputSet else super.references
override def withNewPlan(plan: LogicalPlan): SubqueryExpression
+ override def semanticEquals(o: Expression): Boolean = o match {
+ case p: SubqueryExpression =>
+ this.getClass.getName.equals(p.getClass.getName) && plan.sameResult(p.plan) &&
+ children.length == p.children.length &&
+ children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
+ case _ => false
+ }
}
object SubqueryExpression {
+ /**
+ * Returns true when an expression contains an IN or EXISTS subquery and false otherwise.
+ */
+ def hasInOrExistsSubquery(e: Expression): Boolean = {
+ e.find {
+ case _: ListQuery | _: Exists => true
+ case _ => false
+ }.isDefined
+ }
+
+ /**
+ * Returns true when an expression contains a subquery that has outer reference(s). The outer
+ * reference attributes are kept as children of subquery expression by
+ * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]]
+ */
def hasCorrelatedSubquery(e: Expression): Boolean = {
e.find {
- case e: SubqueryExpression if e.children.nonEmpty => true
+ case s: SubqueryExpression => s.children.nonEmpty
case _ => false
}.isDefined
}
}
+object SubExprUtils extends PredicateHelper {
+ /**
+ * Returns true when an expression contains correlated predicates i.e outer references and
+ * returns false otherwise.
+ */
+ def containsOuter(e: Expression): Boolean = {
+ e.find(_.isInstanceOf[OuterReference]).isDefined
+ }
+
+ /**
+ * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could
+ * turn the null-aware predicate into not-null-aware predicate.
+ */
+ def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = {
+ splitConjunctivePredicates(condition).exists {
+ case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) =>
+ false
+ case e => e.find { x =>
+ x.isInstanceOf[Not] && e.find {
+ case In(_, Seq(_: ListQuery)) => true
+ case _ => false
+ }.isDefined
+ }.isDefined
+ }
+
+ }
+
+ /**
+ * Returns an expression after removing the OuterReference shell.
+ */
+ def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r }
+
+ /**
+ * Returns the list of expressions after removing the OuterReference shell from each of
+ * the expression.
+ */
+ def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference)
+
+ /**
+ * Returns the logical plan after removing the OuterReference shell from all the expressions
+ * of the input logical plan.
+ */
+ def stripOuterReferences(p: LogicalPlan): LogicalPlan = {
+ p.transformAllExpressions {
+ case OuterReference(a) => a
+ }
+ }
+
+ /**
+ * Given a logical plan, returns TRUE if it has an outer reference and false otherwise.
+ */
+ def hasOuterReferences(plan: LogicalPlan): Boolean = {
+ plan.find {
+ case f: Filter => containsOuter(f.condition)
+ case other => false
+ }.isDefined
+ }
+
+ /**
+ * Given a list of expressions, returns the expressions which have outer references. Aggregate
+ * expressions are treated in a special way. If the children of aggregate expression contains an
+ * outer reference, then the entire aggregate expression is marked as an outer reference.
+ * Example (SQL):
+ * {{{
+ * SELECT a FROM l GROUP by 1 HAVING EXISTS (SELECT 1 FROM r WHERE d < min(b))
+ * }}}
+ * In the above case, we want to mark the entire min(b) as an outer reference
+ * OuterReference(min(b)) instead of min(OuterReference(b)).
+ * TODO: Currently we don't allow deep correlation. Also, we don't allow mixing of
+ * outer references and local references under an aggregate expression.
+ * For example (SQL):
+ * {{{
+ * SELECT .. FROM p1
+ * WHERE EXISTS (SELECT ...
+ * FROM p2
+ * WHERE EXISTS (SELECT ...
+ * FROM sq
+ * WHERE min(p1.a + p2.b) = sq.c))
+ *
+ * SELECT .. FROM p1
+ * WHERE EXISTS (SELECT ...
+ * FROM p2
+ * WHERE EXISTS (SELECT ...
+ * FROM sq
+ * WHERE min(p1.a) + max(p2.b) = sq.c))
+ *
+ * SELECT .. FROM p1
+ * WHERE EXISTS (SELECT ...
+ * FROM p2
+ * WHERE EXISTS (SELECT ...
+ * FROM sq
+ * WHERE min(p1.a + sq.c) > 1))
+ * }}}
+ * The code below needs to change when we support the above cases.
+ */
+ def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = {
+ val outerExpressions = ArrayBuffer.empty[Expression]
+ conditions foreach { expr =>
+ expr transformDown {
+ case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) =>
+ val newExpr = stripOuterReference(a)
+ outerExpressions += newExpr
+ newExpr
+ case OuterReference(e) =>
+ outerExpressions += e
+ e
+ }
+ }
+ outerExpressions
+ }
+
+ /**
+ * Returns all the expressions that have outer references from a logical plan. Currently only
+ * Filter operator can host outer references.
+ */
+ def getOuterReferences(plan: LogicalPlan): Seq[Expression] = {
+ val conditions = plan.collect { case Filter(cond, _) => cond }
+ getOuterReferences(conditions)
+ }
+
+ /**
+ * Returns the correlated predicates from a logical plan. The OuterReference wrapper
+ * is removed before returning the predicate to the caller.
+ */
+ def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = {
+ val conditions = plan.collect { case Filter(cond, _) => cond }
+ conditions.flatMap { e =>
+ val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter)
+ stripOuterReferences(correlated) match {
+ case Nil => None
+ case xs => xs
+ }
+ }
+ }
+}
+
/**
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
@@ -63,14 +231,8 @@ case class ScalarSubquery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Unevaluable {
- override lazy val resolved: Boolean = childrenResolved && plan.resolved
- override lazy val references: AttributeSet = {
- if (plan.resolved) super.references -- plan.outputSet
- else super.references
- }
+ extends SubqueryExpression(plan, children, exprId) with Unevaluable {
override def dataType: DataType = plan.schema.fields.head.dataType
- override def foldable: Boolean = false
override def nullable: Boolean = true
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
@@ -79,60 +241,13 @@ case class ScalarSubquery(
object ScalarSubquery {
def hasCorrelatedScalarSubquery(e: Expression): Boolean = {
e.find {
- case e: ScalarSubquery if e.children.nonEmpty => true
+ case s: ScalarSubquery => s.children.nonEmpty
case _ => false
}.isDefined
}
}
/**
- * A predicate subquery checks the existence of a value in a sub-query. We currently only allow
- * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will
- * be rewritten into a left semi/anti join during analysis.
- */
-case class PredicateSubquery(
- plan: LogicalPlan,
- children: Seq[Expression] = Seq.empty,
- nullAware: Boolean = false,
- exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Predicate with Unevaluable {
- override lazy val resolved = childrenResolved && plan.resolved
- override lazy val references: AttributeSet = super.references -- plan.outputSet
- override def nullable: Boolean = nullAware
- override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan)
- override def semanticEquals(o: Expression): Boolean = o match {
- case p: PredicateSubquery =>
- plan.sameResult(p.plan) && nullAware == p.nullAware &&
- children.length == p.children.length &&
- children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
- case _ => false
- }
- override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
-}
-
-object PredicateSubquery {
- def hasPredicateSubquery(e: Expression): Boolean = {
- e.find {
- case _: PredicateSubquery | _: ListQuery | _: Exists => true
- case _ => false
- }.isDefined
- }
-
- /**
- * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could
- * turn the null-aware predicate into not-null-aware predicate.
- */
- def hasNullAwarePredicateWithinNot(e: Expression): Boolean = {
- e.find{ x =>
- x.isInstanceOf[Not] && e.find {
- case p: PredicateSubquery => p.nullAware
- case _ => false
- }.isDefined
- }.isDefined
- }
-}
-
-/**
* A [[ListQuery]] expression defines the query which we want to search in an IN subquery
* expression. It should and can only be used in conjunction with an IN expression.
*
@@ -144,18 +259,20 @@ object PredicateSubquery {
* FROM b)
* }}}
*/
-case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Unevaluable {
- override lazy val resolved = false
- override def children: Seq[Expression] = Seq.empty
- override def dataType: DataType = ArrayType(NullType)
+case class ListQuery(
+ plan: LogicalPlan,
+ children: Seq[Expression] = Seq.empty,
+ exprId: ExprId = NamedExpression.newExprId)
+ extends SubqueryExpression(plan, children, exprId) with Unevaluable {
+ override def dataType: DataType = plan.schema.fields.head.dataType
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
- override def toString: String = s"list#${exprId.id}"
+ override def toString: String = s"list#${exprId.id} $conditionString"
}
/**
* The [[Exists]] expression checks if a row exists in a subquery given some correlated condition.
+ *
* For example (SQL):
* {{{
* SELECT *
@@ -165,11 +282,12 @@ case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExpr
* WHERE b.id = a.id)
* }}}
*/
-case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Predicate with Unevaluable {
- override lazy val resolved = false
- override def children: Seq[Expression] = Seq.empty
+case class Exists(
+ plan: LogicalPlan,
+ children: Seq[Expression] = Seq.empty,
+ exprId: ExprId = NamedExpression.newExprId)
+ extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable {
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
- override def toString: String = s"exists#${exprId.id}"
+ override def toString: String = s"exists#${exprId.id} $conditionString"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index caafa1c134..e9dbded3d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -68,6 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// since the other rules might make two separate Unions operators adjacent.
Batch("Union", Once,
CombineUnions) ::
+ Batch("Pullup Correlated Expressions", Once,
+ PullupCorrelatedPredicates) ::
Batch("Subquery", Once,
OptimizeSubqueries) ::
Batch("Replace Operators", fixedPoint,
@@ -885,7 +887,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = {
val attributes = plan.outputSet
val matched = condition.find {
- case PredicateSubquery(p, _, _, _) => p.outputSet.intersect(attributes).nonEmpty
+ case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty
case _ => false
}
matched.isEmpty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index fb7ce6aece..ba3fd1d5f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -41,10 +42,17 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
+ private def getValueExpression(e: Expression): Seq[Expression] = {
+ e match {
+ case cns : CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Filter(condition, child) =>
val (withSubquery, withoutSubquery) =
- splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery)
+ splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery)
// Construct the pruned filter condition.
val newFilter: LogicalPlan = withoutSubquery match {
@@ -54,20 +62,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
- case (p, PredicateSubquery(sub, conditions, _, _)) =>
+ case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
- case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
+ case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftAnti, joinCond)
- case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
+ case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
+ val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
+ Join(outerPlan, sub, LeftSemi, joinCond)
+ case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
// Note that will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
+ val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
// (a1,b1,...) = (a2,b2,...)
@@ -83,11 +96,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Given a predicate expression and an input plan, it rewrites
- * any embedded existential sub-query into an existential join.
- * It returns the rewritten expression together with the updated plan.
- * Currently, it does not support null-aware joins. Embedded NOT IN predicates
- * are blocked in the Analyzer.
+ * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query
+ * into an existential join. It returns the rewritten expression together with the updated plan.
+ * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in
+ * the Analyzer.
*/
private def rewriteExistentialExpr(
exprs: Seq[Expression],
@@ -95,17 +107,138 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
var newPlan = plan
val newExprs = exprs.map { e =>
e transformUp {
- case PredicateSubquery(sub, conditions, nullAware, _) =>
- // TODO: support null-aware join
+ case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
- }
+ case In(value, Seq(ListQuery(sub, conditions, _))) =>
+ val exists = AttributeReference("exists", BooleanType, nullable = false)()
+ val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+ val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
+ newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions)
+ exists
+ }
}
(newExprs.reduceOption(And), newPlan)
}
}
+ /**
+ * Pull out all (outer) correlated predicates from a given subquery. This method removes the
+ * correlated predicates from subquery [[Filter]]s and adds the references of these predicates
+ * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to
+ * be able to evaluate the predicates at the top level.
+ *
+ * TODO: Look to merge this rule with RewritePredicateSubquery.
+ */
+object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper {
+ /**
+ * Returns the correlated predicates and a updated plan that removes the outer references.
+ */
+ private def pullOutCorrelatedPredicates(
+ sub: LogicalPlan,
+ outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
+ val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
+
+ /** Determine which correlated predicate references are missing from this plan. */
+ def missingReferences(p: LogicalPlan): AttributeSet = {
+ val localPredicateReferences = p.collect(predicateMap)
+ .flatten
+ .map(_.references)
+ .reduceOption(_ ++ _)
+ .getOrElse(AttributeSet.empty)
+ localPredicateReferences -- p.outputSet
+ }
+
+ // Simplify the predicates before pulling them out.
+ val transformed = BooleanSimplification(sub) transformUp {
+ case f @ Filter(cond, child) =>
+ val (correlated, local) =
+ splitConjunctivePredicates(cond).partition(containsOuter)
+
+ // Rewrite the filter without the correlated predicates if any.
+ correlated match {
+ case Nil => f
+ case xs if local.nonEmpty =>
+ val newFilter = Filter(local.reduce(And), child)
+ predicateMap += newFilter -> xs
+ newFilter
+ case xs =>
+ predicateMap += child -> xs
+ child
+ }
+ case p @ Project(expressions, child) =>
+ val referencesToAdd = missingReferences(p)
+ if (referencesToAdd.nonEmpty) {
+ Project(expressions ++ referencesToAdd, child)
+ } else {
+ p
+ }
+ case a @ Aggregate(grouping, expressions, child) =>
+ val referencesToAdd = missingReferences(a)
+ if (referencesToAdd.nonEmpty) {
+ Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
+ } else {
+ a
+ }
+ case p =>
+ p
+ }
+
+ // Make sure the inner and the outer query attributes do not collide.
+ // In case of a collision, change the subquery plan's output to use
+ // different attribute by creating alias(s).
+ val baseConditions = predicateMap.values.flatten.toSeq
+ val (newPlan, newCond) = if (outer.nonEmpty) {
+ val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
+ val duplicates = transformed.outputSet.intersect(outputSet)
+ val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
+ val aliasMap = AttributeMap(duplicates.map { dup =>
+ dup -> Alias(dup, dup.toString)()
+ }.toSeq)
+ val aliasedExpressions = transformed.output.map { ref =>
+ aliasMap.getOrElse(ref, ref)
+ }
+ val aliasedProjection = Project(aliasedExpressions, transformed)
+ val aliasedConditions = baseConditions.map(_.transform {
+ case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
+ })
+ (aliasedProjection, aliasedConditions)
+ } else {
+ (transformed, baseConditions)
+ }
+ (plan, stripOuterReferences(deDuplicatedConditions))
+ } else {
+ (transformed, stripOuterReferences(baseConditions))
+ }
+ (newPlan, newCond)
+ }
+
+ private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = {
+ plan transformExpressions {
+ case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
+ val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+ ScalarSubquery(newPlan, newCond, exprId)
+ case Exists(sub, children, exprId) if children.nonEmpty =>
+ val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+ Exists(newPlan, newCond, exprId)
+ case ListQuery(sub, _, exprId) =>
+ val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+ ListQuery(newPlan, newCond, exprId)
+ }
+ }
+
+ /**
+ * Pull up the correlated predicates and rewrite all subqueries in an operator tree..
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case f @ Filter(_, a: Aggregate) =>
+ rewriteSubQueries(f, Seq(a, a.child))
+ // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
+ case q: UnaryNode =>
+ rewriteSubQueries(q, q.children)
+ }
+}
/**
* This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index c5e877d128..d2ebca5a83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -530,7 +530,7 @@ class AnalysisErrorSuite extends AnalysisTest {
Exists(
Join(
LocalRelation(b),
- Filter(EqualTo(OuterReference(a), c), LocalRelation(c)),
+ Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
LeftOuter,
Option(EqualTo(b, c)))),
LocalRelation(a))
@@ -539,7 +539,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan2 = Filter(
Exists(
Join(
- Filter(EqualTo(OuterReference(a), c), LocalRelation(c)),
+ Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
LocalRelation(b),
RightOuter,
Option(EqualTo(b, c)))),
@@ -547,14 +547,15 @@ class AnalysisErrorSuite extends AnalysisTest {
assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil)
val plan3 = Filter(
- Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))),
+ Exists(Union(LocalRelation(b),
+ Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)))),
LocalRelation(a))
assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
val plan4 = Filter(
Exists(
Limit(1,
- Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))
+ Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
),
LocalRelation(a))
assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
@@ -562,7 +563,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan5 = Filter(
Exists(
Sample(0.0, 0.5, false, 1L,
- Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b)
+ Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b)
),
LocalRelation(a))
assertAnalysisError(plan5,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
index 4aafb2b83f..5569312143 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
@@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest {
val t2 = LocalRelation(b)
test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
- val expr = Filter(In(a, Seq(ListQuery(Project(Seq(OuterReference(a)), t2)))), t1)
+ val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1)
val m = intercept[AnalysisException] {
SimpleAnalyzer.ResolveSubquery(expr)
}.getMessage
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index e9b7a0c6ad..5eb31413ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -43,8 +43,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
e.copy(exprId = ExprId(0))
case l: ListQuery =>
l.copy(exprId = ExprId(0))
- case p: PredicateSubquery =>
- p.copy(exprId = ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 730ca27f82..58be2d1da2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -144,9 +144,6 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
ScalarSubquery(
SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
subquery.exprId)
- case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) =>
- val executedPlan = new QueryExecution(sparkSession, query).executedPlan
- InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
}
}
}
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
index 50ae01e181..f7bbb35aad 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
@@ -46,7 +46,7 @@ and t2b = (select max(avg)
struct<>
-- !query 3 output
org.apache.spark.sql.AnalysisException
-expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
+grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.;
-- !query 4
@@ -63,4 +63,4 @@ where t1a in (select min(t2a)
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
-resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)];
+resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 25dbecb589..6f1cd49c08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -622,7 +622,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") {
checkAnswer(
- sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"),
+ sql(
+ """
+ |select l.b, (select (r.c + count(*)) is null
+ |from r
+ |where l.a = r.c group by r.c) from l
+ """.stripMargin),
Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
}