 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                 | 314
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala            |  40
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala             | 130
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala            |  43
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala              | 256
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala               |   4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala                | 159
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala       |  11
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala     |   2
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala                    |   2
 sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala                             |   3
 sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out |   4
 sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala                                  |   7
 13 files changed, 675 insertions(+), 300 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 93666f1495..a3764d8c84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -21,12 +21,13 @@ import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf, TableIdentifier}
+import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
@@ -162,6 +163,8 @@ class Analyzer(
FixNullability),
Batch("ResolveTimeZone", Once,
ResolveTimeZone),
+ Batch("Subquery", Once,
+ UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)
@@ -710,13 +713,72 @@ class Analyzer(
} transformUp {
case other => other transformExpressions {
case a: Attribute =>
- attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
+ dedupAttr(a, attributeRewrites)
+ case s: SubqueryExpression =>
+ s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
}
}
newRight
}
}
+ private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+ attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier)
+ }
+
+ /**
+ * The outer plan may have been de-duplicated and the function below updates the
+ * outer references to refer to the de-duplicated attributes.
+ *
+ * For example (SQL):
+ * {{{
+ * SELECT * FROM t1
+ * INTERSECT
+ * SELECT * FROM t1
+ * WHERE EXISTS (SELECT 1
+ * FROM t2
+ * WHERE t1.c1 = t2.c1)
+ * }}}
+ * Plan before resolveReference rule.
+ * 'Intersect
+ * :- Project [c1#245, c2#246]
+ * : +- SubqueryAlias t1
+ * : +- Relation[c1#245,c2#246] parquet
+ * +- 'Project [*]
+ * +- Filter exists#257 [c1#245]
+ * : +- Project [1 AS 1#258]
+ * : +- Filter (outer(c1#245) = c1#251)
+ * : +- SubqueryAlias t2
+ * : +- Relation[c1#251,c2#252] parquet
+ * +- SubqueryAlias t1
+ * +- Relation[c1#245,c2#246] parquet
+ * Plan after the resolveReference rule.
+ * Intersect
+ * :- Project [c1#245, c2#246]
+ * : +- SubqueryAlias t1
+ * : +- Relation[c1#245,c2#246] parquet
+ * +- Project [c1#259, c2#260]
+ * +- Filter exists#257 [c1#259]
+ * : +- Project [1 AS 1#258]
+ * : +- Filter (outer(c1#259) = c1#251) => Updated
+ * : +- SubqueryAlias t2
+ * : +- Relation[c1#251,c2#252] parquet
+ * +- SubqueryAlias t1
+ * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
+ */
+ private def dedupOuterReferencesInSubquery(
+ plan: LogicalPlan,
+ attrMap: AttributeMap[Attribute]): LogicalPlan = {
+ plan transformDown { case currentFragment =>
+ currentFragment transformExpressions {
+ case OuterReference(a: Attribute) =>
+ OuterReference(dedupAttr(a, attrMap))
+ case s: SubqueryExpression =>
+ s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
+ }
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
@@ -1132,28 +1194,21 @@ class Analyzer(
}
/**
- * Pull out all (outer) correlated predicates from a given subquery. This method removes the
- * correlated predicates from subquery [[Filter]]s and adds the references of these predicates
- * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to
- * be able to evaluate the predicates at the top level.
- *
- * This method returns the rewritten subquery and correlated predicates.
+ * Validates that the outer references appearing inside the subquery
+ * are legal. This function also returns the list of expressions
+ * that contain outer references. These outer references are kept as children
+ * of the subquery expression by the caller of this function.
*/
- private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
- val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
+ private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = {
+ val outerReferences = ArrayBuffer.empty[Expression]
// Make sure a plan's subtree does not contain outer references
def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
- if (p.collectFirst(predicateMap).nonEmpty) {
+ if (hasOuterReferences(p)) {
failAnalysis(s"Accessing outer query column is not allowed in:\n$p")
}
}
- // Helper function for locating outer references.
- def containsOuter(e: Expression): Boolean = {
- e.find(_.isInstanceOf[OuterReference]).isDefined
- }
-
// Make sure a plan's expressions do not contain outer references
def failOnOuterReference(p: LogicalPlan): Unit = {
if (p.expressions.exists(containsOuter)) {
@@ -1194,20 +1249,11 @@ class Analyzer(
}
}
- /** Determine which correlated predicate references are missing from this plan. */
- def missingReferences(p: LogicalPlan): AttributeSet = {
- val localPredicateReferences = p.collect(predicateMap)
- .flatten
- .map(_.references)
- .reduceOption(_ ++ _)
- .getOrElse(AttributeSet.empty)
- localPredicateReferences -- p.outputSet
- }
-
var foundNonEqualCorrelatedPred : Boolean = false
- // Simplify the predicates before pulling them out.
- val transformed = BooleanSimplification(sub) transformUp {
+ // Simplify the predicates before validating any unsupported correlation patterns
+ // in the plan.
+ BooleanSimplification(sub).foreachUp {
// Whitelist operators allowed in a correlated subquery
// There are 4 categories:
@@ -1229,80 +1275,48 @@ class Analyzer(
// Category 1:
// BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias
- case p: BroadcastHint =>
- p
- case p: Distinct =>
- p
- case p: LeafNode =>
- p
- case p: Repartition =>
- p
- case p: SubqueryAlias =>
- p
+ case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>
// Category 2:
// These operators can be anywhere in a correlated subquery.
// so long as they do not host outer references in the operators.
- case p: Sort =>
- failOnOuterReference(p)
- p
- case p: RepartitionByExpression =>
- failOnOuterReference(p)
- p
+ case s: Sort =>
+ failOnOuterReference(s)
+ case r: RepartitionByExpression =>
+ failOnOuterReference(r)
// Category 3:
// Filter is one of the two operators allowed to host correlated expressions.
// The other operator is Join. Filter can be anywhere in a correlated subquery.
- case f @ Filter(cond, child) =>
+ case f: Filter =>
// Find all predicates with an outer reference.
- val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)
+ val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)
// Find any non-equality correlated predicates
foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
case _: EqualTo | _: EqualNullSafe => false
case _ => true
}
-
- // Rewrite the filter without the correlated predicates if any.
- correlated match {
- case Nil => f
- case xs if local.nonEmpty =>
- val newFilter = Filter(local.reduce(And), child)
- predicateMap += newFilter -> xs
- newFilter
- case xs =>
- predicateMap += child -> xs
- child
- }
+ // The aggregate expressions are treated in a special way by getOuterReferences. If the
+ // aggregate expression contains only outer reference attributes then the entire aggregate
+ // expression is isolated as an OuterReference.
+ // i.e. min(OuterReference(b)) => OuterReference(min(b))
+ outerReferences ++= getOuterReferences(correlated)
// Project cannot host any correlated expressions
// but can be anywhere in a correlated subquery.
- case p @ Project(expressions, child) =>
+ case p: Project =>
failOnOuterReference(p)
- val referencesToAdd = missingReferences(p)
- if (referencesToAdd.nonEmpty) {
- Project(expressions ++ referencesToAdd, child)
- } else {
- p
- }
-
// Aggregate cannot host any correlated expressions
// It can be on a correlation path if the correlation contains
// only equality correlated predicates.
// It cannot be on a correlation path if the correlation has
// non-equality correlated predicates.
- case a @ Aggregate(grouping, expressions, child) =>
+ case a: Aggregate =>
failOnOuterReference(a)
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
- val referencesToAdd = missingReferences(a)
- if (referencesToAdd.nonEmpty) {
- Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
- } else {
- a
- }
-
// Join can host correlated expressions.
case j @ Join(left, right, joinType, _) =>
joinType match {
@@ -1332,7 +1346,6 @@ class Analyzer(
case _ =>
failOnOuterReferenceInSubTree(j)
}
- j
// Generator with join=true, i.e., expressed with
// LATERAL VIEW [OUTER], similar to inner join,
@@ -1340,9 +1353,8 @@ class Analyzer(
// but must not host any outer references.
// Note:
// Generator with join=false is treated as Category 4.
- case p @ Generate(generator, true, _, _, _, _) =>
- failOnOuterReference(p)
- p
+ case g: Generate if g.join =>
+ failOnOuterReference(g)
// Category 4: Any other operators not in the above 3 categories
// cannot be on a correlation path, that is they are allowed only
@@ -1350,54 +1362,17 @@ class Analyzer(
// are not allowed to have any correlated expressions.
case p =>
failOnOuterReferenceInSubTree(p)
- p
}
- (transformed, predicateMap.values.flatten.toSeq)
+ outerReferences
}
/**
- * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same
- * attributes.
- */
- private def rewriteSubQuery(
- sub: LogicalPlan,
- outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
- // Pull out the tagged predicates and rewrite the subquery in the process.
- val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub)
-
- // Make sure the inner and the outer query attributes do not collide.
- val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
- val duplicates = basePlan.outputSet.intersect(outputSet)
- val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
- val aliasMap = AttributeMap(duplicates.map { dup =>
- dup -> Alias(dup, dup.toString)()
- }.toSeq)
- val aliasedExpressions = basePlan.output.map { ref =>
- aliasMap.getOrElse(ref, ref)
- }
- val aliasedProjection = Project(aliasedExpressions, basePlan)
- val aliasedConditions = baseConditions.map(_.transform {
- case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
- })
- (aliasedProjection, aliasedConditions)
- } else {
- (basePlan, baseConditions)
- }
- // Remove outer references from the correlated predicates. We wait with extracting
- // these until collisions between the inner and outer query attributes have been
- // solved.
- val conditions = deDuplicatedConditions.map(_.transform {
- case OuterReference(ref) => ref
- })
- (plan, conditions)
- }
-
- /**
- * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method
+ * Resolves the subquery. The subquery is resolved using its outer plans. This method
* will resolve the subquery by alternating between the regular analyzer and by applying the
* resolveOuterReferences rule.
*
- * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved.
+ * Outer references from the correlated predicates are recorded as children of
+ * the subquery expression.
*/
private def resolveSubQuery(
e: SubqueryExpression,
@@ -1420,7 +1395,8 @@ class Analyzer(
}
} while (!current.resolved && !current.fastEquals(previous))
- // Step 2: Pull out the predicates if the plan is resolved.
+ // Step 2: If the subquery plan is fully resolved, pull the outer references and record
+ // them as children of SubqueryExpression.
if (current.resolved) {
// Make sure the resolved query has the required number of output columns. This is only
// needed for Scalar and IN subqueries.
@@ -1428,34 +1404,37 @@ class Analyzer(
failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
s"does not match the required number of columns ($requiredColumns)")
}
- // Pullout predicates and construct a new plan.
- f.tupled(rewriteSubQuery(current, plans))
+ // Validate the outer references and record them as children of the
+ // subquery expression.
+ f(current, checkAndGetOuterReferences(current))
} else {
e.withNewPlan(current)
}
}
/**
- * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS
- * expressions into PredicateSubquery expression once the are resolved.
+ * Resolves the subquery. Besides resolving the subquery and any outer references
+ * in the subquery plan, the children of the subquery expression are updated to record
+ * the outer references. This is needed to make sure
+ * (1) The column(s) referred to from the outer query are not pruned from the plan during
+ * optimization.
+ * (2) Any aggregate expression(s) that reference outer attributes are pushed down to the
+ * outer plan for evaluation.
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan transformExpressions {
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
- case e @ Exists(sub, exprId) =>
- resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
- case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
+ case e @ Exists(sub, _, exprId) if !sub.resolved =>
+ resolveSubQuery(e, plans)(Exists(_, _, exprId))
+ case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
// Get the left hand side expressions.
- val expressions = e match {
+ val expressions = value match {
case cns : CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
- resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
- // Construct the IN conditions.
- val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled)
- PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId)
- }
+ val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId))
+ In(value, Seq(expr))
}
}
@@ -2353,6 +2332,11 @@ class Analyzer(
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
e.withTimeZone(conf.sessionLocalTimeZone)
+ // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
+ // the types between the value expression and the list query expression of an IN expression.
+ // We need to run the subquery plan through ResolveTimeZone again to set up time zone
+ // information for the time zone aware expressions.
+ case e: ListQuery => e.withNewPlan(apply(e.plan))
}
}
}
@@ -2533,3 +2517,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
CreateNamedStruct(children.toList)
}
}
+
+/**
+ * The aggregate expressions from a subquery that reference the outer query block are pushed
+ * down to the outer query block for evaluation. The rule below updates such outer references
+ * to AttributeReferences that refer to attributes from the parent/outer query block.
+ *
+ * For example (SQL):
+ * {{{
+ * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b))
+ * }}}
+ * Plan before the rule.
+ * Project [a#226]
+ * +- Filter exists#245 [min(b#227)#249]
+ * : +- Project [1 AS 1#247]
+ * : +- Filter (d#238 < min(outer(b#227))) <-----
+ * : +- SubqueryAlias r
+ * : +- Project [_1#234 AS c#237, _2#235 AS d#238]
+ * : +- LocalRelation [_1#234, _2#235]
+ * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
+ * +- SubqueryAlias l
+ * +- Project [_1#223 AS a#226, _2#224 AS b#227]
+ * +- LocalRelation [_1#223, _2#224]
+ * Plan after the rule.
+ * Project [a#226]
+ * +- Filter exists#245 [min(b#227)#249]
+ * : +- Project [1 AS 1#247]
+ * : +- Filter (d#238 < outer(min(b#227)#249)) <-----
+ * : +- SubqueryAlias r
+ * : +- Project [_1#234 AS c#237, _2#235 AS d#238]
+ * : +- LocalRelation [_1#234, _2#235]
+ * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
+ * +- SubqueryAlias l
+ * +- Project [_1#223 AS a#226, _2#224 AS b#227]
+ * +- LocalRelation [_1#223, _2#224]
+ */
+object UpdateOuterReferences extends Rule[LogicalPlan] {
+ private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child }
+
+ private def updateOuterReferenceInSubquery(
+ plan: LogicalPlan,
+ refExprs: Seq[Expression]): LogicalPlan = {
+ plan transformAllExpressions { case e =>
+ val outerAlias =
+ refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e)))
+ outerAlias match {
+ case Some(a: Alias) => OuterReference(a.toAttribute)
+ case _ => e
+ }
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan transform {
+ case f @ Filter(_, a: Aggregate) if f.resolved =>
+ f transformExpressions {
+ case s: SubqueryExpression if s.children.nonEmpty =>
+ // Collect the aliases from output of aggregate.
+ val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
+ // Update the subquery plan to record the OuterReference to point to outer query plan.
+ s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases))
+ }
+ }
+ }
+}
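
For illustration, a minimal spark-shell sketch of the HAVING-EXISTS pattern that UpdateOuterReferences handles. The table and column names follow the scaladoc example above; the session setup itself is an assumption, not part of the patch:

    import spark.implicits._
    // Hypothetical tables matching the example: l(a, b) and r(c, d).
    Seq((1, 2.0), (2, 1.0)).toDF("a", "b").createOrReplaceTempView("l")
    Seq((2, 3.0), (3, 2.0)).toDF("c", "d").createOrReplaceTempView("r")
    // min(l.b) is evaluated by the outer Aggregate; the rule rewrites the
    // subquery's outer(min(b#...)) into a reference to the aggregate's output.
    spark.sql(
      "SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b))"
    ).show()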
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index d32fbeb4e9..da0c6b098f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper {
if (conditions.isEmpty && query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
- } else if (conditions.nonEmpty) {
- // Collect the columns from the subquery for further checking.
- var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
-
+ } else if (conditions.nonEmpty) {
def checkAggregate(agg: Aggregate): Unit = {
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates containing exactly one aggregate expression.
@@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper {
// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
+ // Collect the local references from the correlated predicate in the subquery.
+ val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
+ .filterNot(conditions.flatMap(_.references).contains)
val correlatedCols = AttributeSet(subqueryColumns)
val invalidCols = groupByCols -- correlatedCols
// GROUP BY columns must be a subset of columns in the predicates
@@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper {
// For projects, do the necessary mapping and skip to its child.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQuery(s.child)
- case p: Project =>
- // SPARK-18814: Map any aliases to their AttributeReference children
- // for the checking in the Aggregate operators below this Project.
- subqueryColumns = subqueryColumns.map {
- xs => p.projectList.collectFirst {
- case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
- child
- }.getOrElse(xs)
- }
-
- cleanQuery(p.child)
+ case p: Project => cleanQuery(p.child)
case child => child
}
@@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper {
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
- case Filter(condition, _) =>
- splitConjunctivePredicates(condition).foreach {
- case _: PredicateSubquery | Not(_: PredicateSubquery) =>
- case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
- failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" +
- s" conditions: $e")
- case e =>
- }
+ case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) =>
+ failAnalysis("Null-aware predicate sub-queries cannot be used in nested " +
+ s"conditions: $condition")
case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
@@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper {
s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
}
- case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
- failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+ case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) =>
+ p match {
+ case _: Filter => // Ok
+ case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+ }
case _: Union | _: SetOperation if operator.children.length > 1 =>
def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
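
As a hedged illustration of the relocated predicate-subquery check (the tables l and r are assumptions), an IN subquery is accepted when hosted by a Filter and rejected elsewhere:

    // Accepted: the predicate subquery sits in a WHERE clause (a Filter).
    spark.sql("SELECT * FROM l WHERE a IN (SELECT c FROM r)")
    // Rejected by CheckAnalysis with
    // "Predicate sub-queries can only be used in a Filter: ...":
    spark.sql("SELECT a IN (SELECT c FROM r) AS flag FROM l")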
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 2c00957bd6..768897dc07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -109,6 +109,28 @@ object TypeCoercion {
}
/**
+ * This function determines the target type of a comparison operator when one operand
+ * is a String and the other is not. It also handles the case where one operand is a Date
+ * and the other is a Timestamp, by choosing String as the target type.
+ */
+ val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
+ // We should cast all relative timestamp/date/string comparison into string comparisons
+ // This behaves as a user would expect because timestamp strings sort lexicographically.
+ // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
+ case (StringType, DateType) => Some(StringType)
+ case (DateType, StringType) => Some(StringType)
+ case (StringType, TimestampType) => Some(StringType)
+ case (TimestampType, StringType) => Some(StringType)
+ case (TimestampType, DateType) => Some(StringType)
+ case (DateType, TimestampType) => Some(StringType)
+ case (StringType, NullType) => Some(StringType)
+ case (NullType, StringType) => Some(StringType)
+ case (l: StringType, r: AtomicType) if r != StringType => Some(r)
+ case (l: AtomicType, r: StringType) if l != StringType => Some(l)
+ case (l, r) => None
+ }
+
+ /**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
* i.e. the main difference with [[findTightestCommonType]] is that here we allow some
@@ -305,6 +327,14 @@ object TypeCoercion {
* Promotes strings that appear in arithmetic expressions.
*/
object PromoteStrings extends Rule[LogicalPlan] {
+ private def castExpr(expr: Expression, targetType: DataType): Expression = {
+ (expr.dataType, targetType) match {
+ case (NullType, dt) => Literal.create(null, targetType)
+ case (l, dt) if (l != dt) => Cast(expr, targetType)
+ case _ => expr
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -321,37 +351,10 @@ object TypeCoercion {
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))
- // We should cast all relative timestamp/date/string comparison into string comparisons
- // This behaves as a user would expect because timestamp strings sort lexicographically.
- // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
- case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
- p.makeCopy(Array(left, Cast(right, StringType)))
- case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
- p.makeCopy(Array(Cast(left, StringType), right))
- case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
- p.makeCopy(Array(left, Cast(right, StringType)))
- case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
- p.makeCopy(Array(Cast(left, StringType), right))
-
- // Comparisons between dates and timestamps.
- case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
- p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
- case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
- p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
-
- // Checking NullType
- case p @ BinaryComparison(left @ StringType(), right @ NullType()) =>
- p.makeCopy(Array(left, Literal.create(null, StringType)))
- case p @ BinaryComparison(left @ NullType(), right @ StringType()) =>
- p.makeCopy(Array(Literal.create(null, StringType), right))
-
- // When compare string with atomic type, case string to that type.
- case p @ BinaryComparison(left @ StringType(), right @ AtomicType())
- if right.dataType != StringType =>
- p.makeCopy(Array(Cast(left, right.dataType), right))
- case p @ BinaryComparison(left @ AtomicType(), right @ StringType())
- if left.dataType != StringType =>
- p.makeCopy(Array(left, Cast(right, left.dataType)))
+ case p @ BinaryComparison(left, right)
+ if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
+ val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
+ p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
@@ -365,17 +368,72 @@ object TypeCoercion {
}
/**
- * Convert the value and in list expressions to the common operator type
- * by looking at all the argument types and finding the closest one that
- * all the arguments can be cast to. When no common operator type is found
- * the original expression will be returned and an Analysis Exception will
- * be raised at type checking phase.
+ * Handles type coercion for both IN expressions with a subquery and IN
+ * expressions without a subquery.
+ * 1. In the first case, find the common type by comparing the left hand side (LHS)
+ * expression types against the corresponding right hand side (RHS) expression types
+ * derived from the subquery expression's plan output. Inject appropriate casts on
+ * both the LHS and RHS of the IN expression.
+ *
+ * 2. In the second case, convert the value and in-list expressions to the
+ * common operator type by looking at all the argument types and finding
+ * the closest one that all the arguments can be cast to. When no common
+ * operator type is found the original expression will be returned and an
+ * AnalysisException will be raised at the type checking phase.
*/
object InConversion extends Rule[LogicalPlan] {
+ private def flattenExpr(expr: Expression): Seq[Expression] = {
+ expr match {
+ // Multi columns in IN clause is represented as a CreateNamedStruct.
+ // flatten the named struct to get the list of expressions.
+ case cns: CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
+ // Handle type casting required between value expression and subquery output
+ // in IN subquery.
+ case i @ In(a, Seq(ListQuery(sub, children, exprId)))
+ if !i.resolved && flattenExpr(a).length == sub.output.length =>
+ // LHS is the value expression of IN subquery.
+ val lhs = flattenExpr(a)
+
+ // RHS is the subquery output.
+ val rhs = sub.output
+
+ val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
+ findCommonTypeForBinaryComparison(l.dataType, r.dataType)
+ .orElse(findTightestCommonType(l.dataType, r.dataType))
+ }
+
+ // The number of columns/expressions must match between LHS and RHS of an
+ // IN subquery expression.
+ if (commonTypes.length == lhs.length) {
+ val castedRhs = rhs.zip(commonTypes).map {
+ case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+ case (e, _) => e
+ }
+ val castedLhs = lhs.zip(commonTypes).map {
+ case (e, dt) if e.dataType != dt => Cast(e, dt)
+ case (e, _) => e
+ }
+
+ // Before constructing the In expression, wrap the multiple values on the LHS
+ // in a CreateStruct.
+ val newLhs = castedLhs match {
+ case Seq(lhs) => lhs
+ case _ => CreateStruct(castedLhs)
+ }
+
+ In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
+ } else {
+ i
+ }
+
case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
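
A sketch of the new IN-subquery coercion path (the tables t1 and t2 and their types are assumptions): when the LHS is an INT and the subquery output is DECIMAL, InConversion casts the LHS and wraps the subquery plan in a Project of the casted output.

    import spark.implicits._
    Seq(1, 2).toDF("a").createOrReplaceTempView("t1")
    Seq(BigDecimal("1.50"), BigDecimal("2.00")).toDF("c").createOrReplaceTempView("t2")
    // The analyzed plan shows a Cast on the value side and a Project of the
    // casted subquery output on the list side.
    spark.sql("SELECT * FROM t1 WHERE a IN (SELECT c FROM t2)").explain(true)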
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ac56ff13fa..e5d1a1e299 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -123,19 +123,44 @@ case class Not(child: Expression)
*/
@ExpressionDescription(
usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.")
-case class In(value: Expression, list: Seq[Expression]) extends Predicate
- with ImplicitCastInputTypes {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate {
require(list != null, "list should not be null")
+ override def checkInputDataTypes(): TypeCheckResult = {
+ list match {
+ case ListQuery(sub, _, _) :: Nil =>
+ val valExprs = value match {
+ case cns: CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
+ }
- override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
+ val mismatchedColumns = valExprs.zip(sub.output).flatMap {
+ case (l, r) if l.dataType != r.dataType =>
+ Some(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})")
+ case _ => None
+ }
- override def checkInputDataTypes(): TypeCheckResult = {
- if (list.exists(l => l.dataType != value.dataType)) {
- TypeCheckResult.TypeCheckFailure(
- "Arguments must be same type")
- } else {
- TypeCheckResult.TypeCheckSuccess
+ if (mismatchedColumns.nonEmpty) {
+ TypeCheckResult.TypeCheckFailure(
+ s"""
+ |The data type of one or more elements in the left hand side of an IN subquery
+ |is not compatible with the data type of the output of the subquery.
+ |Mismatched columns:
+ |[${mismatchedColumns.mkString(", ")}]
+ |Left side:
+ |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
+ |Right side:
+ |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
+ """.stripMargin)
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ case _ =>
+ if (list.exists(l => l.dataType != value.dataType)) {
+ TypeCheckResult.TypeCheckFailure("Arguments must be same type")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
}
}
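
The improved error message above fires only when coercion cannot reconcile the two sides. A hedged example (table shapes assumed): BOOLEAN has no common comparison type with DATE, so checkInputDataTypes reports the mismatched column pair.

    // Hypothetical tables: t1(a BOOLEAN), t2(c DATE). Analysis fails with
    // "... data type mismatch: The data type of one or more elements in the
    //  left hand side of an IN subquery is not compatible with the data type
    //  of the output of the subquery."
    spark.sql("SELECT * FROM t1 WHERE a IN (SELECT c FROM t2)")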
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index e2e7d98e33..ad11700fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.types._
/**
@@ -40,19 +43,184 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
/**
* A base interface for expressions that contain a [[LogicalPlan]].
*/
-abstract class SubqueryExpression extends PlanExpression[LogicalPlan] {
+abstract class SubqueryExpression(
+ plan: LogicalPlan,
+ children: Seq[Expression],
+ exprId: ExprId) extends PlanExpression[LogicalPlan] {
+
+ override lazy val resolved: Boolean = childrenResolved && plan.resolved
+ override lazy val references: AttributeSet =
+ if (plan.resolved) super.references -- plan.outputSet else super.references
override def withNewPlan(plan: LogicalPlan): SubqueryExpression
+ override def semanticEquals(o: Expression): Boolean = o match {
+ case p: SubqueryExpression =>
+ this.getClass.getName.equals(p.getClass.getName) && plan.sameResult(p.plan) &&
+ children.length == p.children.length &&
+ children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
+ case _ => false
+ }
}
object SubqueryExpression {
+ /**
+ * Returns true when an expression contains an IN or EXISTS subquery and false otherwise.
+ */
+ def hasInOrExistsSubquery(e: Expression): Boolean = {
+ e.find {
+ case _: ListQuery | _: Exists => true
+ case _ => false
+ }.isDefined
+ }
+
+ /**
+ * Returns true when an expression contains a subquery that has outer reference(s). The outer
+ * reference attributes are kept as children of subquery expression by
+ * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]]
+ */
def hasCorrelatedSubquery(e: Expression): Boolean = {
e.find {
- case e: SubqueryExpression if e.children.nonEmpty => true
+ case s: SubqueryExpression => s.children.nonEmpty
case _ => false
}.isDefined
}
}
+object SubExprUtils extends PredicateHelper {
+ /**
+ * Returns true when an expression contains correlated predicates (i.e. outer references),
+ * and false otherwise.
+ */
+ def containsOuter(e: Expression): Boolean = {
+ e.find(_.isInstanceOf[OuterReference]).isDefined
+ }
+
+ /**
+ * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could
+ * turn the null-aware predicate into not-null-aware predicate.
+ */
+ def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = {
+ splitConjunctivePredicates(condition).exists {
+ case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) =>
+ false
+ case e => e.find { x =>
+ x.isInstanceOf[Not] && e.find {
+ case In(_, Seq(_: ListQuery)) => true
+ case _ => false
+ }.isDefined
+ }.isDefined
+ }
+
+ }
+
+ /**
+ * Returns an expression after removing the OuterReference shell.
+ */
+ def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r }
+
+ /**
+ * Returns the list of expressions after removing the OuterReference shell from each of
+ * the expressions.
+ */
+ def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference)
+
+ /**
+ * Returns the logical plan after removing the OuterReference shell from all the expressions
+ * of the input logical plan.
+ */
+ def stripOuterReferences(p: LogicalPlan): LogicalPlan = {
+ p.transformAllExpressions {
+ case OuterReference(a) => a
+ }
+ }
+
+ /**
+ * Given a logical plan, returns true if it has an outer reference and false otherwise.
+ */
+ def hasOuterReferences(plan: LogicalPlan): Boolean = {
+ plan.find {
+ case f: Filter => containsOuter(f.condition)
+ case other => false
+ }.isDefined
+ }
+
+ /**
+ * Given a list of expressions, returns the expressions which have outer references. Aggregate
+ * expressions are treated in a special way. If the children of an aggregate expression contain an
+ * outer reference, then the entire aggregate expression is marked as an outer reference.
+ * Example (SQL):
+ * {{{
+ * SELECT a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE d < min(b))
+ * }}}
+ * In the above case, we want to mark the entire min(b) as an outer reference
+ * OuterReference(min(b)) instead of min(OuterReference(b)).
+ * TODO: Currently we don't allow deep correlation. Also, we don't allow mixing of
+ * outer references and local references under an aggregate expression.
+ * For example (SQL):
+ * {{{
+ * SELECT .. FROM p1
+ * WHERE EXISTS (SELECT ...
+ * FROM p2
+ * WHERE EXISTS (SELECT ...
+ * FROM sq
+ * WHERE min(p1.a + p2.b) = sq.c))
+ *
+ * SELECT .. FROM p1
+ * WHERE EXISTS (SELECT ...
+ * FROM p2
+ * WHERE EXISTS (SELECT ...
+ * FROM sq
+ * WHERE min(p1.a) + max(p2.b) = sq.c))
+ *
+ * SELECT .. FROM p1
+ * WHERE EXISTS (SELECT ...
+ * FROM p2
+ * WHERE EXISTS (SELECT ...
+ * FROM sq
+ * WHERE min(p1.a + sq.c) > 1))
+ * }}}
+ * The code below needs to change when we support the above cases.
+ */
+ def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = {
+ val outerExpressions = ArrayBuffer.empty[Expression]
+ conditions foreach { expr =>
+ expr transformDown {
+ case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) =>
+ val newExpr = stripOuterReference(a)
+ outerExpressions += newExpr
+ newExpr
+ case OuterReference(e) =>
+ outerExpressions += e
+ e
+ }
+ }
+ outerExpressions
+ }
+
+ /**
+ * Returns all the expressions that have outer references from a logical plan. Currently only
+ * Filter operator can host outer references.
+ */
+ def getOuterReferences(plan: LogicalPlan): Seq[Expression] = {
+ val conditions = plan.collect { case Filter(cond, _) => cond }
+ getOuterReferences(conditions)
+ }
+
+ /**
+ * Returns the correlated predicates from a logical plan. The OuterReference wrapper
+ * is removed before returning the predicate to the caller.
+ */
+ def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = {
+ val conditions = plan.collect { case Filter(cond, _) => cond }
+ conditions.flatMap { e =>
+ val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter)
+ stripOuterReferences(correlated) match {
+ case Nil => None
+ case xs => xs
+ }
+ }
+ }
+}
+
/**
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
@@ -63,14 +231,8 @@ case class ScalarSubquery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Unevaluable {
- override lazy val resolved: Boolean = childrenResolved && plan.resolved
- override lazy val references: AttributeSet = {
- if (plan.resolved) super.references -- plan.outputSet
- else super.references
- }
+ extends SubqueryExpression(plan, children, exprId) with Unevaluable {
override def dataType: DataType = plan.schema.fields.head.dataType
- override def foldable: Boolean = false
override def nullable: Boolean = true
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
@@ -79,60 +241,13 @@ case class ScalarSubquery(
object ScalarSubquery {
def hasCorrelatedScalarSubquery(e: Expression): Boolean = {
e.find {
- case e: ScalarSubquery if e.children.nonEmpty => true
+ case s: ScalarSubquery => s.children.nonEmpty
case _ => false
}.isDefined
}
}
/**
- * A predicate subquery checks the existence of a value in a sub-query. We currently only allow
- * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will
- * be rewritten into a left semi/anti join during analysis.
- */
-case class PredicateSubquery(
- plan: LogicalPlan,
- children: Seq[Expression] = Seq.empty,
- nullAware: Boolean = false,
- exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Predicate with Unevaluable {
- override lazy val resolved = childrenResolved && plan.resolved
- override lazy val references: AttributeSet = super.references -- plan.outputSet
- override def nullable: Boolean = nullAware
- override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan)
- override def semanticEquals(o: Expression): Boolean = o match {
- case p: PredicateSubquery =>
- plan.sameResult(p.plan) && nullAware == p.nullAware &&
- children.length == p.children.length &&
- children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
- case _ => false
- }
- override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
-}
-
-object PredicateSubquery {
- def hasPredicateSubquery(e: Expression): Boolean = {
- e.find {
- case _: PredicateSubquery | _: ListQuery | _: Exists => true
- case _ => false
- }.isDefined
- }
-
- /**
- * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could
- * turn the null-aware predicate into not-null-aware predicate.
- */
- def hasNullAwarePredicateWithinNot(e: Expression): Boolean = {
- e.find{ x =>
- x.isInstanceOf[Not] && e.find {
- case p: PredicateSubquery => p.nullAware
- case _ => false
- }.isDefined
- }.isDefined
- }
-}
-
-/**
* A [[ListQuery]] expression defines the query which we want to search in an IN subquery
* expression. It should and can only be used in conjunction with an IN expression.
*
@@ -144,18 +259,20 @@ object PredicateSubquery {
* FROM b)
* }}}
*/
-case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Unevaluable {
- override lazy val resolved = false
- override def children: Seq[Expression] = Seq.empty
- override def dataType: DataType = ArrayType(NullType)
+case class ListQuery(
+ plan: LogicalPlan,
+ children: Seq[Expression] = Seq.empty,
+ exprId: ExprId = NamedExpression.newExprId)
+ extends SubqueryExpression(plan, children, exprId) with Unevaluable {
+ override def dataType: DataType = plan.schema.fields.head.dataType
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
- override def toString: String = s"list#${exprId.id}"
+ override def toString: String = s"list#${exprId.id} $conditionString"
}
/**
* The [[Exists]] expression checks if a row exists in a subquery given some correlated condition.
+ *
* For example (SQL):
* {{{
* SELECT *
@@ -165,11 +282,12 @@ case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExpr
* WHERE b.id = a.id)
* }}}
*/
-case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
- extends SubqueryExpression with Predicate with Unevaluable {
- override lazy val resolved = false
- override def children: Seq[Expression] = Seq.empty
+case class Exists(
+ plan: LogicalPlan,
+ children: Seq[Expression] = Seq.empty,
+ exprId: ExprId = NamedExpression.newExprId)
+ extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable {
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
- override def toString: String = s"exists#${exprId.id}"
+ override def toString: String = s"exists#${exprId.id} $conditionString"
}
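
A minimal sketch of the new SubExprUtils helpers, driven directly against Catalyst expression objects (internal API; the constructor shapes below match this patch):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
    import org.apache.spark.sql.types.IntegerType

    val b = AttributeReference("b", IntegerType)()
    val cond = EqualTo(OuterReference(b), Literal(1))
    assert(containsOuter(cond))                        // detects the OuterReference shell
    assert(stripOuterReference(cond) == EqualTo(b, Literal(1)))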
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index caafa1c134..e9dbded3d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -68,6 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// since the other rules might make two separate Unions operators adjacent.
Batch("Union", Once,
CombineUnions) ::
+ Batch("Pullup Correlated Expressions", Once,
+ PullupCorrelatedPredicates) ::
Batch("Subquery", Once,
OptimizeSubqueries) ::
Batch("Replace Operators", fixedPoint,
@@ -885,7 +887,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = {
val attributes = plan.outputSet
val matched = condition.find {
- case PredicateSubquery(p, _, _, _) => p.outputSet.intersect(attributes).nonEmpty
+ case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty
case _ => false
}
matched.isEmpty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index fb7ce6aece..ba3fd1d5f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -41,10 +42,17 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
+ private def getValueExpression(e: Expression): Seq[Expression] = {
+ e match {
+ case cns : CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Filter(condition, child) =>
val (withSubquery, withoutSubquery) =
- splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery)
+ splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery)
// Construct the pruned filter condition.
val newFilter: LogicalPlan = withoutSubquery match {
@@ -54,20 +62,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
- case (p, PredicateSubquery(sub, conditions, _, _)) =>
+ case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
- case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
+ case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftAnti, joinCond)
- case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
+ case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
+ val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
+ Join(outerPlan, sub, LeftSemi, joinCond)
+ case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
// Note that this will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
+ val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
// (a1,b1,...) = (a2,b2,...)
@@ -83,11 +96,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Given a predicate expression and an input plan, it rewrites
- * any embedded existential sub-query into an existential join.
- * It returns the rewritten expression together with the updated plan.
- * Currently, it does not support null-aware joins. Embedded NOT IN predicates
- * are blocked in the Analyzer.
+ * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query
+ * into an existential join. It returns the rewritten expression together with the updated plan.
+ * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in
+ * the Analyzer.
*/
private def rewriteExistentialExpr(
exprs: Seq[Expression],
@@ -95,17 +107,138 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
var newPlan = plan
val newExprs = exprs.map { e =>
e transformUp {
- case PredicateSubquery(sub, conditions, nullAware, _) =>
- // TODO: support null-aware join
+ case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
- }
+ case In(value, Seq(ListQuery(sub, conditions, _))) =>
+ val exists = AttributeReference("exists", BooleanType, nullable = false)()
+ val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+ val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
+ newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions)
+ exists
+ }
}
(newExprs.reduceOption(And), newPlan)
}
}
+ /**
+ * Pull out all (outer) correlated predicates from a given subquery. This method removes the
+ * correlated predicates from subquery [[Filter]]s and adds the references of these predicates
+ * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to
+ * be able to evaluate the predicates at the top level.
+ *
+ * TODO: Look to merge this rule with RewritePredicateSubquery.
+ */
+object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper {
+ /**
+ * Returns the correlated predicates and an updated plan that removes the outer references.
+ */
+ private def pullOutCorrelatedPredicates(
+ sub: LogicalPlan,
+ outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
+ val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
+
+ /** Determine which correlated predicate references are missing from this plan. */
+ def missingReferences(p: LogicalPlan): AttributeSet = {
+ val localPredicateReferences = p.collect(predicateMap)
+ .flatten
+ .map(_.references)
+ .reduceOption(_ ++ _)
+ .getOrElse(AttributeSet.empty)
+ localPredicateReferences -- p.outputSet
+ }
+
+ // Simplify the predicates before pulling them out.
+ val transformed = BooleanSimplification(sub) transformUp {
+ case f @ Filter(cond, child) =>
+ val (correlated, local) =
+ splitConjunctivePredicates(cond).partition(containsOuter)
+
+ // Rewrite the filter without the correlated predicates if any.
+ correlated match {
+ case Nil => f
+ case xs if local.nonEmpty =>
+ val newFilter = Filter(local.reduce(And), child)
+ predicateMap += newFilter -> xs
+ newFilter
+ case xs =>
+ predicateMap += child -> xs
+ child
+ }
+ case p @ Project(expressions, child) =>
+ val referencesToAdd = missingReferences(p)
+ if (referencesToAdd.nonEmpty) {
+ Project(expressions ++ referencesToAdd, child)
+ } else {
+ p
+ }
+ case a @ Aggregate(grouping, expressions, child) =>
+ val referencesToAdd = missingReferences(a)
+ if (referencesToAdd.nonEmpty) {
+ Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
+ } else {
+ a
+ }
+ case p =>
+ p
+ }
+
+ // Make sure the inner and the outer query attributes do not collide.
+ // In case of a collision, change the subquery plan's output to use
+ // different attributes by creating aliases.
+ val baseConditions = predicateMap.values.flatten.toSeq
+ val (newPlan, newCond) = if (outer.nonEmpty) {
+ val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
+ val duplicates = transformed.outputSet.intersect(outputSet)
+ val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
+ val aliasMap = AttributeMap(duplicates.map { dup =>
+ dup -> Alias(dup, dup.toString)()
+ }.toSeq)
+ val aliasedExpressions = transformed.output.map { ref =>
+ aliasMap.getOrElse(ref, ref)
+ }
+ val aliasedProjection = Project(aliasedExpressions, transformed)
+ val aliasedConditions = baseConditions.map(_.transform {
+ case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
+ })
+ (aliasedProjection, aliasedConditions)
+ } else {
+ (transformed, baseConditions)
+ }
+ (plan, stripOuterReferences(deDuplicatedConditions))
+ } else {
+ (transformed, stripOuterReferences(baseConditions))
+ }
+ (newPlan, newCond)
+ }
+
+ private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = {
+ plan transformExpressions {
+ case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
+ val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+ ScalarSubquery(newPlan, newCond, exprId)
+ case Exists(sub, children, exprId) if children.nonEmpty =>
+ val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+ Exists(newPlan, newCond, exprId)
+ case ListQuery(sub, _, exprId) =>
+ val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+ ListQuery(newPlan, newCond, exprId)
+ }
+ }
+
+ /**
+ * Pull up the correlated predicates and rewrite all subqueries in an operator tree.
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case f @ Filter(_, a: Aggregate) =>
+ rewriteSubQueries(f, Seq(a, a.child))
+ // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
+ case q: UnaryNode =>
+ rewriteSubQueries(q, q.children)
+ }
+}
/**
* This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins.
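
To see PullupCorrelatedPredicates and RewritePredicateSubquery work together, a hedged spark-shell sketch (tables assumed): the correlated IN predicate is first pulled into the ListQuery's children and then planned as a LEFT SEMI join.

    import spark.implicits._
    Seq((1, 2), (2, 4)).toDF("a", "b").createOrReplaceTempView("l")
    Seq((2, 4)).toDF("c", "d").createOrReplaceTempView("r")
    // The optimized plan printed by explain(true) contains a LeftSemi join
    // whose condition combines the IN equality (a = c) with the pulled-up
    // correlated predicate (b = d).
    spark.sql("SELECT * FROM l WHERE a IN (SELECT c FROM r WHERE l.b = r.d)").explain(true)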
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index c5e877d128..d2ebca5a83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -530,7 +530,7 @@ class AnalysisErrorSuite extends AnalysisTest {
Exists(
Join(
LocalRelation(b),
- Filter(EqualTo(OuterReference(a), c), LocalRelation(c)),
+ Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
LeftOuter,
Option(EqualTo(b, c)))),
LocalRelation(a))
@@ -539,7 +539,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan2 = Filter(
Exists(
Join(
- Filter(EqualTo(OuterReference(a), c), LocalRelation(c)),
+ Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
LocalRelation(b),
RightOuter,
Option(EqualTo(b, c)))),
@@ -547,14 +547,15 @@ class AnalysisErrorSuite extends AnalysisTest {
assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil)
val plan3 = Filter(
- Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))),
+ Exists(Union(LocalRelation(b),
+ Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)))),
LocalRelation(a))
assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
val plan4 = Filter(
Exists(
Limit(1,
- Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))
+ Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
),
LocalRelation(a))
assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
@@ -562,7 +563,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan5 = Filter(
Exists(
Sample(0.0, 0.5, false, 1L,
- Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b)
+ Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b)
),
LocalRelation(a))
assertAnalysisError(plan5,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
index 4aafb2b83f..5569312143 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
@@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest {
val t2 = LocalRelation(b)
test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
- val expr = Filter(In(a, Seq(ListQuery(Project(Seq(OuterReference(a)), t2)))), t1)
+ val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1)
val m = intercept[AnalysisException] {
SimpleAnalyzer.ResolveSubquery(expr)
}.getMessage
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index e9b7a0c6ad..5eb31413ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -43,8 +43,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
e.copy(exprId = ExprId(0))
case l: ListQuery =>
l.copy(exprId = ExprId(0))
- case p: PredicateSubquery =>
- p.copy(exprId = ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 730ca27f82..58be2d1da2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -144,9 +144,6 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
ScalarSubquery(
SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
subquery.exprId)
- case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) =>
- val executedPlan = new QueryExecution(sparkSession, query).executedPlan
- InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
}
}
}
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
index 50ae01e181..f7bbb35aad 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
@@ -46,7 +46,7 @@ and t2b = (select max(avg)
struct<>
-- !query 3 output
org.apache.spark.sql.AnalysisException
-expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
+grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.;
-- !query 4
@@ -63,4 +63,4 @@ where t1a in (select min(t2a)
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
-resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)];
+resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 25dbecb589..6f1cd49c08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -622,7 +622,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") {
checkAnswer(
- sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"),
+ sql(
+ """
+ |select l.b, (select (r.c + count(*)) is null
+ |from r
+ |where l.a = r.c group by r.c) from l
+ """.stripMargin),
Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
}
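
The test change above tracks the stricter aggregation check that comes with this patch: once correlated predicates are no longer pulled out during analysis, r.c in the scalar subquery's select list must be grouped explicitly. A minimal sketch of the failing/passing pair (same tables as the suite):

    // Fails analysis now: r.c is neither grouped nor aggregated.
    sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l")
    // Passes: r.c appears in GROUP BY, as in the updated test.
    sql("""select l.b, (select (r.c + count(*)) is null
          |from r
          |where l.a = r.c group by r.c) from l""".stripMargin)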