diff options
Diffstat (limited to 'sql/catalyst')
10 files changed, 667 insertions, 294 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 93666f1495..a3764d8c84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,12 +21,13 @@ import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} @@ -162,6 +163,8 @@ class Analyzer( FixNullability), Batch("ResolveTimeZone", Once, ResolveTimeZone), + Batch("Subquery", Once, + UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -710,13 +713,72 @@ class Analyzer( } transformUp { case other => other transformExpressions { case a: Attribute => - attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier) + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } newRight } } + private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + } + + /** + * The outer plan may have been de-duplicated and the function below updates the + * outer references to refer to the de-duplicated attributes. + * + * For example (SQL): + * {{{ + * SELECT * FROM t1 + * INTERSECT + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * }}} + * Plan before resolveReference rule. + * 'Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- 'Project [*] + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated. + */ + private def dedupOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(dedupAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap)) + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -1132,28 +1194,21 @@ class Analyzer( } /** - * Pull out all (outer) correlated predicates from a given subquery. This method removes the - * correlated predicates from subquery [[Filter]]s and adds the references of these predicates - * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to - * be able to evaluate the predicates at the top level. - * - * This method returns the rewritten subquery and correlated predicates. + * Validates to make sure the outer references appearing inside the subquery + * are legal. This function also returns the list of expressions + * that contain outer references. These outer references would be kept as children + * of subquery expressions by the caller of this function. */ - private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { + val outerReferences = ArrayBuffer.empty[Expression] // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (p.collectFirst(predicateMap).nonEmpty) { + if (hasOuterReferences(p)) { failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } - // Helper function for locating outer references. - def containsOuter(e: Expression): Boolean = { - e.find(_.isInstanceOf[OuterReference]).isDefined - } - // Make sure a plan's expressions do not contain outer references def failOnOuterReference(p: LogicalPlan): Unit = { if (p.expressions.exists(containsOuter)) { @@ -1194,20 +1249,11 @@ class Analyzer( } } - /** Determine which correlated predicate references are missing from this plan. */ - def missingReferences(p: LogicalPlan): AttributeSet = { - val localPredicateReferences = p.collect(predicateMap) - .flatten - .map(_.references) - .reduceOption(_ ++ _) - .getOrElse(AttributeSet.empty) - localPredicateReferences -- p.outputSet - } - var foundNonEqualCorrelatedPred : Boolean = false - // Simplify the predicates before pulling them out. - val transformed = BooleanSimplification(sub) transformUp { + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: @@ -1229,80 +1275,48 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case p: BroadcastHint => - p - case p: Distinct => - p - case p: LeafNode => - p - case p: Repartition => - p - case p: SubqueryAlias => - p + case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. - case p: Sort => - failOnOuterReference(p) - p - case p: RepartitionByExpression => - failOnOuterReference(p) - p + case s: Sort => + failOnOuterReference(s) + case r: RepartitionByExpression => + failOnOuterReference(r) // Category 3: // Filter is one of the two operators allowed to host correlated expressions. // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f @ Filter(cond, child) => + case f: Filter => // Find all predicates with an outer reference. - val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) // Find any non-equality correlated predicates foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { case _: EqualTo | _: EqualNullSafe => false case _ => true } - - // Rewrite the filter without the correlated predicates if any. - correlated match { - case Nil => f - case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) - predicateMap += newFilter -> xs - newFilter - case xs => - predicateMap += child -> xs - child - } + // The aggregate expressions are treated in a special way by getOuterReferences. If the + // aggregate expression contains only outer reference attributes then the entire aggregate + // expression is isolated as an OuterReference. + // i.e min(OuterReference(b)) => OuterReference(min(b)) + outerReferences ++= getOuterReferences(correlated) // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. - case p @ Project(expressions, child) => + case p: Project => failOnOuterReference(p) - val referencesToAdd = missingReferences(p) - if (referencesToAdd.nonEmpty) { - Project(expressions ++ referencesToAdd, child) - } else { - p - } - // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains // only equality correlated predicates. // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. - case a @ Aggregate(grouping, expressions, child) => + case a: Aggregate => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - val referencesToAdd = missingReferences(a) - if (referencesToAdd.nonEmpty) { - Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) - } else { - a - } - // Join can host correlated expressions. case j @ Join(left, right, joinType, _) => joinType match { @@ -1332,7 +1346,6 @@ class Analyzer( case _ => failOnOuterReferenceInSubTree(j) } - j // Generator with join=true, i.e., expressed with // LATERAL VIEW [OUTER], similar to inner join, @@ -1340,9 +1353,8 @@ class Analyzer( // but must not host any outer references. // Note: // Generator with join=false is treated as Category 4. - case p @ Generate(generator, true, _, _, _, _) => - failOnOuterReference(p) - p + case g: Generate if g.join => + failOnOuterReference(g) // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only @@ -1350,54 +1362,17 @@ class Analyzer( // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - p } - (transformed, predicateMap.values.flatten.toSeq) + outerReferences } /** - * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same - * attributes. - */ - private def rewriteSubQuery( - sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { - // Pull out the tagged predicates and rewrite the subquery in the process. - val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - - // Make sure the inner and the outer query attributes do not collide. - val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = basePlan.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = basePlan.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, basePlan) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (basePlan, baseConditions) - } - // Remove outer references from the correlated predicates. We wait with extracting - // these until collisions between the inner and outer query attributes have been - // solved. - val conditions = deDuplicatedConditions.map(_.transform { - case OuterReference(ref) => ref - }) - (plan, conditions) - } - - /** - * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method + * Resolves the subquery. The subquery is resolved using its outer plans. This method * will resolve the subquery by alternating between the regular analyzer and by applying the * resolveOuterReferences rule. * - * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved. + * Outer references from the correlated predicates are updated as children of + * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, @@ -1420,7 +1395,8 @@ class Analyzer( } } while (!current.resolved && !current.fastEquals(previous)) - // Step 2: Pull out the predicates if the plan is resolved. + // Step 2: If the subquery plan is fully resolved, pull the outer references and record + // them as children of SubqueryExpression. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only // needed for Scalar and IN subqueries. @@ -1428,34 +1404,37 @@ class Analyzer( failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + s"does not match the required number of columns ($requiredColumns)") } - // Pullout predicates and construct a new plan. - f.tupled(rewriteSubQuery(current, plans)) + // Validate the outer reference and record the outer references as children of + // subquery expression. + f(current, checkAndGetOuterReferences(current)) } else { e.withNewPlan(current) } } /** - * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS - * expressions into PredicateSubquery expression once the are resolved. + * Resolves the subquery. Apart of resolving the subquery and outer references (if any) + * in the subquery plan, the children of subquery expression are updated to record the + * outer references. This is needed to make sure + * (1) The column(s) referred from the outer query are not pruned from the plan during + * optimization. + * (2) Any aggregate expression(s) that reference outer attributes are pushed down to + * outer plan to get evaluated. */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, exprId) => - resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) - case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => + case e @ Exists(sub, _, exprId) if !sub.resolved => + resolveSubQuery(e, plans)(Exists(_, _, exprId)) + case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => // Get the left hand side expressions. - val expressions = e match { + val expressions = value match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => - // Construct the IN conditions. - val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) - PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) - } + val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + In(value, Seq(expr)) } } @@ -2353,6 +2332,11 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(apply(e.plan)) } } } @@ -2533,3 +2517,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { CreateNamedStruct(children.toList) } } + +/** + * The aggregate expressions from subquery referencing outer query block are pushed + * down to the outer query block for evaluation. This rule below updates such outer references + * as AttributeReference referring attributes from the parent/outer query block. + * + * For example (SQL): + * {{{ + * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b)) + * }}} + * Plan before the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < min(outer(b#227))) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + * Plan after the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < outer(min(b#227)#249)) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + */ +object UpdateOuterReferences extends Rule[LogicalPlan] { + private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } + + private def updateOuterReferenceInSubquery( + plan: LogicalPlan, + refExprs: Seq[Expression]): LogicalPlan = { + plan transformAllExpressions { case e => + val outerAlias = + refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) + outerAlias match { + case Some(a: Alias) => OuterReference(a.toAttribute) + case _ => e + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case f @ Filter(_, a: Aggregate) if f.resolved => + f transformExpressions { + case s: SubqueryExpression if s.children.nonEmpty => + // Collect the aliases from output of aggregate. + val outerAliases = a.aggregateExpressions collect { case a: Alias => a } + // Update the subquery plan to record the OuterReference to point to outer query plan. + s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d32fbeb4e9..da0c6b098f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper { if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") - } else if (conditions.nonEmpty) { - // Collect the columns from the subquery for further checking. - var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) - + } + else if (conditions.nonEmpty) { def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates containing exactly one aggregate expression. @@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper { // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns // are not part of the correlated columns. val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) val correlatedCols = AttributeSet(subqueryColumns) val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates @@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper { // For projects, do the necessary mapping and skip to its child. def cleanQuery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => - // SPARK-18814: Map any aliases to their AttributeReference children - // for the checking in the Aggregate operators below this Project. - subqueryColumns = subqueryColumns.map { - xs => p.projectList.collectFirst { - case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => - child - }.getOrElse(xs) - } - - cleanQuery(p.child) + case p: Project => cleanQuery(p.child) case child => child } @@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case Filter(condition, _) => - splitConjunctivePredicates(condition).foreach { - case _: PredicateSubquery | Not(_: PredicateSubquery) => - case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => - failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + - s" conditions: $e") - case e => - } + case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => + failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + + s"conditions: $condition") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( @@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper { s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") } - case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => - failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => + p match { + case _: Filter => // Ok + case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + } case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2c00957bd6..768897dc07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -109,6 +109,28 @@ object TypeCoercion { } /** + * This function determines the target type of a comparison operator when one operand + * is a String and the other is not. It also handles when one op is a Date and the + * other is a Timestamp by making the target type to be String. + */ + val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + // We should cast all relative timestamp/date/string comparison into string comparisons + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true + case (StringType, DateType) => Some(StringType) + case (DateType, StringType) => Some(StringType) + case (StringType, TimestampType) => Some(StringType) + case (TimestampType, StringType) => Some(StringType) + case (TimestampType, DateType) => Some(StringType) + case (DateType, TimestampType) => Some(StringType) + case (StringType, NullType) => Some(StringType) + case (NullType, StringType) => Some(StringType) + case (l: StringType, r: AtomicType) if r != StringType => Some(r) + case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l, r) => None + } + + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * * i.e. the main difference with [[findTightestCommonType]] is that here we allow some @@ -305,6 +327,14 @@ object TypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { + private def castExpr(expr: Expression, targetType: DataType): Expression = { + (expr.dataType, targetType) match { + case (NullType, dt) => Literal.create(null, targetType) + case (l, dt) if (l != dt) => Cast(expr, targetType) + case _ => expr + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -321,37 +351,10 @@ object TypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisons - // This behaves as a user would expect because timestamp strings sort lexicographically. - // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true - case p @ BinaryComparison(left @ StringType(), right @ DateType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - - // Comparisons between dates and timestamps. - case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - - // Checking NullType - case p @ BinaryComparison(left @ StringType(), right @ NullType()) => - p.makeCopy(Array(left, Literal.create(null, StringType))) - case p @ BinaryComparison(left @ NullType(), right @ StringType()) => - p.makeCopy(Array(Literal.create(null, StringType), right)) - - // When compare string with atomic type, case string to that type. - case p @ BinaryComparison(left @ StringType(), right @ AtomicType()) - if right.dataType != StringType => - p.makeCopy(Array(Cast(left, right.dataType), right)) - case p @ BinaryComparison(left @ AtomicType(), right @ StringType()) - if left.dataType != StringType => - p.makeCopy(Array(left, Cast(right, left.dataType))) + case p @ BinaryComparison(left, right) + if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) @@ -365,17 +368,72 @@ object TypeCoercion { } /** - * Convert the value and in list expressions to the common operator type - * by looking at all the argument types and finding the closest one that - * all the arguments can be cast to. When no common operator type is found - * the original expression will be returned and an Analysis Exception will - * be raised at type checking phase. + * Handles type coercion for both IN expression with subquery and IN + * expressions without subquery. + * 1. In the first case, find the common type by comparing the left hand side (LHS) + * expression types against corresponding right hand side (RHS) expression derived + * from the subquery expression's plan output. Inject appropriate casts in the + * LHS and RHS side of IN expression. + * + * 2. In the second case, convert the value and in list expressions to the + * common operator type by looking at all the argument types and finding + * the closest one that all the arguments can be cast to. When no common + * operator type is found the original expression will be returned and an + * Analysis Exception will be raised at the type checking phase. */ object InConversion extends Rule[LogicalPlan] { + private def flattenExpr(expr: Expression): Seq[Expression] = { + expr match { + // Multi columns in IN clause is represented as a CreateNamedStruct. + // flatten the named struct to get the list of expressions. + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Handle type casting required between value expression and subquery output + // in IN subquery. + case i @ In(a, Seq(ListQuery(sub, children, exprId))) + if !i.resolved && flattenExpr(a).length == sub.output.length => + // LHS is the value expression of IN subquery. + val lhs = flattenExpr(a) + + // RHS is the subquery output. + val rhs = sub.output + + val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => + findCommonTypeForBinaryComparison(l.dataType, r.dataType) + .orElse(findTightestCommonType(l.dataType, r.dataType)) + } + + // The number of columns/expressions must match between LHS and RHS of an + // IN subquery expression. + if (commonTypes.length == lhs.length) { + val castedRhs = rhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + val castedLhs = lhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Cast(e, dt) + case (e, _) => e + } + + // Before constructing the In expression, wrap the multi values in LHS + // in a CreatedNamedStruct. + val newLhs = castedLhs match { + case Seq(lhs) => lhs + case _ => CreateStruct(castedLhs) + } + + In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) + } else { + i + } + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac56ff13fa..e5d1a1e299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -123,19 +123,44 @@ case class Not(child: Expression) */ @ExpressionDescription( usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.") -case class In(value: Expression, list: Seq[Expression]) extends Predicate - with ImplicitCastInputTypes { +case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") + override def checkInputDataTypes(): TypeCheckResult = { + list match { + case ListQuery(sub, _, _) :: Nil => + val valExprs = value match { + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } - override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } - override def checkInputDataTypes(): TypeCheckResult = { - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be same type") - } else { - TypeCheckResult.TypeCheckSuccess + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } + case _ => + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure("Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e2e7d98e33..ad11700fa2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.types._ /** @@ -40,19 +43,184 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { /** * A base interface for expressions that contain a [[LogicalPlan]]. */ -abstract class SubqueryExpression extends PlanExpression[LogicalPlan] { +abstract class SubqueryExpression( + plan: LogicalPlan, + children: Seq[Expression], + exprId: ExprId) extends PlanExpression[LogicalPlan] { + + override lazy val resolved: Boolean = childrenResolved && plan.resolved + override lazy val references: AttributeSet = + if (plan.resolved) super.references -- plan.outputSet else super.references override def withNewPlan(plan: LogicalPlan): SubqueryExpression + override def semanticEquals(o: Expression): Boolean = o match { + case p: SubqueryExpression => + this.getClass.getName.equals(p.getClass.getName) && plan.sameResult(p.plan) && + children.length == p.children.length && + children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) + case _ => false + } } object SubqueryExpression { + /** + * Returns true when an expression contains an IN or EXISTS subquery and false otherwise. + */ + def hasInOrExistsSubquery(e: Expression): Boolean = { + e.find { + case _: ListQuery | _: Exists => true + case _ => false + }.isDefined + } + + /** + * Returns true when an expression contains a subquery that has outer reference(s). The outer + * reference attributes are kept as children of subquery expression by + * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]] + */ def hasCorrelatedSubquery(e: Expression): Boolean = { e.find { - case e: SubqueryExpression if e.children.nonEmpty => true + case s: SubqueryExpression => s.children.nonEmpty case _ => false }.isDefined } } +object SubExprUtils extends PredicateHelper { + /** + * Returns true when an expression contains correlated predicates i.e outer references and + * returns false otherwise. + */ + def containsOuter(e: Expression): Boolean = { + e.find(_.isInstanceOf[OuterReference]).isDefined + } + + /** + * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could + * turn the null-aware predicate into not-null-aware predicate. + */ + def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { + splitConjunctivePredicates(condition).exists { + case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => + false + case e => e.find { x => + x.isInstanceOf[Not] && e.find { + case In(_, Seq(_: ListQuery)) => true + case _ => false + }.isDefined + }.isDefined + } + + } + + /** + * Returns an expression after removing the OuterReference shell. + */ + def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r } + + /** + * Returns the list of expressions after removing the OuterReference shell from each of + * the expression. + */ + def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference) + + /** + * Returns the logical plan after removing the OuterReference shell from all the expressions + * of the input logical plan. + */ + def stripOuterReferences(p: LogicalPlan): LogicalPlan = { + p.transformAllExpressions { + case OuterReference(a) => a + } + } + + /** + * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. + */ + def hasOuterReferences(plan: LogicalPlan): Boolean = { + plan.find { + case f: Filter => containsOuter(f.condition) + case other => false + }.isDefined + } + + /** + * Given a list of expressions, returns the expressions which have outer references. Aggregate + * expressions are treated in a special way. If the children of aggregate expression contains an + * outer reference, then the entire aggregate expression is marked as an outer reference. + * Example (SQL): + * {{{ + * SELECT a FROM l GROUP by 1 HAVING EXISTS (SELECT 1 FROM r WHERE d < min(b)) + * }}} + * In the above case, we want to mark the entire min(b) as an outer reference + * OuterReference(min(b)) instead of min(OuterReference(b)). + * TODO: Currently we don't allow deep correlation. Also, we don't allow mixing of + * outer references and local references under an aggregate expression. + * For example (SQL): + * {{{ + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a) + max(p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + sq.c) > 1)) + * }}} + * The code below needs to change when we support the above cases. + */ + def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { + val outerExpressions = ArrayBuffer.empty[Expression] + conditions foreach { expr => + expr transformDown { + case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => + val newExpr = stripOuterReference(a) + outerExpressions += newExpr + newExpr + case OuterReference(e) => + outerExpressions += e + e + } + } + outerExpressions + } + + /** + * Returns all the expressions that have outer references from a logical plan. Currently only + * Filter operator can host outer references. + */ + def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + getOuterReferences(conditions) + } + + /** + * Returns the correlated predicates from a logical plan. The OuterReference wrapper + * is removed before returning the predicate to the caller. + */ + def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + conditions.flatMap { e => + val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter) + stripOuterReferences(correlated) match { + case Nil => None + case xs => xs + } + } + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. @@ -63,14 +231,8 @@ case class ScalarSubquery( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved: Boolean = childrenResolved && plan.resolved - override lazy val references: AttributeSet = { - if (plan.resolved) super.references -- plan.outputSet - else super.references - } + extends SubqueryExpression(plan, children, exprId) with Unevaluable { override def dataType: DataType = plan.schema.fields.head.dataType - override def foldable: Boolean = false override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" @@ -79,60 +241,13 @@ case class ScalarSubquery( object ScalarSubquery { def hasCorrelatedScalarSubquery(e: Expression): Boolean = { e.find { - case e: ScalarSubquery if e.children.nonEmpty => true + case s: ScalarSubquery => s.children.nonEmpty case _ => false }.isDefined } } /** - * A predicate subquery checks the existence of a value in a sub-query. We currently only allow - * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will - * be rewritten into a left semi/anti join during analysis. - */ -case class PredicateSubquery( - plan: LogicalPlan, - children: Seq[Expression] = Seq.empty, - nullAware: Boolean = false, - exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = childrenResolved && plan.resolved - override lazy val references: AttributeSet = super.references -- plan.outputSet - override def nullable: Boolean = nullAware - override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan) - override def semanticEquals(o: Expression): Boolean = o match { - case p: PredicateSubquery => - plan.sameResult(p.plan) && nullAware == p.nullAware && - children.length == p.children.length && - children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) - case _ => false - } - override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" -} - -object PredicateSubquery { - def hasPredicateSubquery(e: Expression): Boolean = { - e.find { - case _: PredicateSubquery | _: ListQuery | _: Exists => true - case _ => false - }.isDefined - } - - /** - * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could - * turn the null-aware predicate into not-null-aware predicate. - */ - def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { - e.find{ x => - x.isInstanceOf[Not] && e.find { - case p: PredicateSubquery => p.nullAware - case _ => false - }.isDefined - }.isDefined - } -} - -/** * A [[ListQuery]] expression defines the query which we want to search in an IN subquery * expression. It should and can only be used in conjunction with an IN expression. * @@ -144,18 +259,20 @@ object PredicateSubquery { * FROM b) * }}} */ -case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty - override def dataType: DataType = ArrayType(NullType) +case class ListQuery( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Unevaluable { + override def dataType: DataType = plan.schema.fields.head.dataType override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) - override def toString: String = s"list#${exprId.id}" + override def toString: String = s"list#${exprId.id} $conditionString" } /** * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * * For example (SQL): * {{{ * SELECT * @@ -165,11 +282,12 @@ case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExpr * WHERE b.id = a.id) * }}} */ -case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty +case class Exists( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def toString: String = s"exists#${exprId.id}" + override def toString: String = s"exists#${exprId.id} $conditionString" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index caafa1c134..e9dbded3d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,6 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + Batch("Pullup Correlated Expressions", Once, + PullupCorrelatedPredicates) :: Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, @@ -885,7 +887,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { val attributes = plan.outputSet val matched = condition.find { - case PredicateSubquery(p, _, _, _) => p.outputSet.intersect(attributes).nonEmpty + case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty case _ => false } matched.isEmpty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index fb7ce6aece..ba3fd1d5f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -41,10 +42,17 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { + private def getValueExpression(e: Expression): Seq[Expression] = { + e match { + case cns : CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Filter(condition, child) => val (withSubquery, withoutSubquery) = - splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) + splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery) // Construct the pruned filter condition. val newFilter: LogicalPlan = withoutSubquery match { @@ -54,20 +62,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { - case (p, PredicateSubquery(sub, conditions, _, _)) => + case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, false, _))) => + case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, true, _))) => + case (p, In(value, Seq(ListQuery(sub, conditions, _)))) => + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + Join(outerPlan, sub, LeftSemi, joinCond) + case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: // (a1,b1,...) = (a2,b2,...) @@ -83,11 +96,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } /** - * Given a predicate expression and an input plan, it rewrites - * any embedded existential sub-query into an existential join. - * It returns the rewritten expression together with the updated plan. - * Currently, it does not support null-aware joins. Embedded NOT IN predicates - * are blocked in the Analyzer. + * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query + * into an existential join. It returns the rewritten expression together with the updated plan. + * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in + * the Analyzer. */ private def rewriteExistentialExpr( exprs: Seq[Expression], @@ -95,17 +107,138 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { var newPlan = plan val newExprs = exprs.map { e => e transformUp { - case PredicateSubquery(sub, conditions, nullAware, _) => - // TODO: support null-aware join + case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists - } + case In(value, Seq(ListQuery(sub, conditions, _))) => + val exists = AttributeReference("exists", BooleanType, nullable = false)() + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions) + exists + } } (newExprs.reduceOption(And), newPlan) } } + /** + * Pull out all (outer) correlated predicates from a given subquery. This method removes the + * correlated predicates from subquery [[Filter]]s and adds the references of these predicates + * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to + * be able to evaluate the predicates at the top level. + * + * TODO: Look to merge this rule with RewritePredicateSubquery. + */ +object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper { + /** + * Returns the correlated predicates and a updated plan that removes the outer references. + */ + private def pullOutCorrelatedPredicates( + sub: LogicalPlan, + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + + /** Determine which correlated predicate references are missing from this plan. */ + def missingReferences(p: LogicalPlan): AttributeSet = { + val localPredicateReferences = p.collect(predicateMap) + .flatten + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + localPredicateReferences -- p.outputSet + } + + // Simplify the predicates before pulling them out. + val transformed = BooleanSimplification(sub) transformUp { + case f @ Filter(cond, child) => + val (correlated, local) = + splitConjunctivePredicates(cond).partition(containsOuter) + + // Rewrite the filter without the correlated predicates if any. + correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> xs + newFilter + case xs => + predicateMap += child -> xs + child + } + case p @ Project(expressions, child) => + val referencesToAdd = missingReferences(p) + if (referencesToAdd.nonEmpty) { + Project(expressions ++ referencesToAdd, child) + } else { + p + } + case a @ Aggregate(grouping, expressions, child) => + val referencesToAdd = missingReferences(a) + if (referencesToAdd.nonEmpty) { + Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) + } else { + a + } + case p => + p + } + + // Make sure the inner and the outer query attributes do not collide. + // In case of a collision, change the subquery plan's output to use + // different attribute by creating alias(s). + val baseConditions = predicateMap.values.flatten.toSeq + val (newPlan, newCond) = if (outer.nonEmpty) { + val outputSet = outer.map(_.outputSet).reduce(_ ++ _) + val duplicates = transformed.outputSet.intersect(outputSet) + val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = transformed.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, transformed) + val aliasedConditions = baseConditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (transformed, baseConditions) + } + (plan, stripOuterReferences(deDuplicatedConditions)) + } else { + (transformed, stripOuterReferences(baseConditions)) + } + (newPlan, newCond) + } + + private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = { + plan transformExpressions { + case ScalarSubquery(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ScalarSubquery(newPlan, newCond, exprId) + case Exists(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + Exists(newPlan, newCond, exprId) + case ListQuery(sub, _, exprId) => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ListQuery(newPlan, newCond, exprId) + } + } + + /** + * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case f @ Filter(_, a: Aggregate) => + rewriteSubQueries(f, Seq(a, a.child)) + // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. + case q: UnaryNode => + rewriteSubQueries(q, q.children) + } +} /** * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index c5e877d128..d2ebca5a83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -530,7 +530,7 @@ class AnalysisErrorSuite extends AnalysisTest { Exists( Join( LocalRelation(b), - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LeftOuter, Option(EqualTo(b, c)))), LocalRelation(a)) @@ -539,7 +539,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Filter( Exists( Join( - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LocalRelation(b), RightOuter, Option(EqualTo(b, c)))), @@ -547,14 +547,15 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) val plan3 = Filter( - Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), + Exists(Union(LocalRelation(b), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) val plan4 = Filter( Exists( Limit(1, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) ), LocalRelation(a)) assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) @@ -562,7 +563,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 4aafb2b83f..5569312143 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(OuterReference(a)), t2)))), t1) + val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.ResolveSubquery(expr) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e9b7a0c6ad..5eb31413ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -43,8 +43,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { e.copy(exprId = ExprId(0)) case l: ListQuery => l.copy(exprId = ExprId(0)) - case p: PredicateSubquery => - p.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => |