diff options
7 files changed, 173 insertions, 49 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8595762988..182e459d8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -862,28 +862,68 @@ class Analyzer( object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper { /** - * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a + * Resolve the correlated predicates in the clauses (e.g. WHERE or HAVING) of a * sub-query by using the plan the predicates should be correlated to. */ - private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = { - q transformUp { - case f @ Filter(cond, child) if child.resolved && !f.resolved => - val newCond = resolveExpression(cond, p, throws = false) - if (!cond.fastEquals(newCond)) { - Filter(newCond, child) - } else { - f - } + private def resolveCorrelatedSubquery( + sub: LogicalPlan, outer: LogicalPlan, + aliases: scala.collection.mutable.Map[Attribute, Alias]): LogicalPlan = { + // First resolve as much of the sub-query as possible + val analyzed = execute(sub) + if (analyzed.resolved) { + analyzed + } else { + // Only resolve the lowest plan that is not resolved by outer plan, otherwise it could be + // resolved by itself + val resolvedByOuter = analyzed transformDown { + case q: LogicalPlan if q.childrenResolved && !q.resolved => + q transformExpressions { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { + try { + val outerAttrOpt = outer.resolve(nameParts, resolver) + if (outerAttrOpt.isDefined) { + val outerAttr = outerAttrOpt.get + if (q.inputSet.contains(outerAttr)) { + // Got a conflict, create an alias for the attribute come from outer table + val alias = Alias(outerAttr, outerAttr.toString)() + val attr = alias.toAttribute + aliases += attr -> alias + attr + } else { + outerAttr + } + } else { + u + } + } catch { + case a: AnalysisException => u + } + } + } + } + if (resolvedByOuter fastEquals analyzed) { + analyzed + } else { + resolveCorrelatedSubquery(resolvedByOuter, outer, aliases) + } } } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case q: LogicalPlan if q.childrenResolved => - q transformExpressions { + // Only a few unary node (Project/Filter/Aggregate/Having) could have subquery + case q: UnaryNode if q.childrenResolved => + val aliases = scala.collection.mutable.Map[Attribute, Alias]() + val newPlan = q transformExpressions { case e: SubqueryExpression if !e.query.resolved => - // First resolve as much of the sub-query as possible. After that we use the children of - // this plan to resolve the remaining correlated predicates. - e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates)) + e.withNewPlan(resolveCorrelatedSubquery(e.query, q.child, aliases)) + } + if (aliases.nonEmpty) { + val projs = q.child.output ++ aliases.values + Project(q.child.output, + newPlan.withNewChildren(Seq(Project(projs, q.child)))) + } else { + newPlan } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 45e4d535c1..a50b9a1e1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -113,16 +113,21 @@ trait CheckAnalysis extends PredicateHelper { case f @ Filter(condition, child) => // Make sure that no correlated reference is below Aggregates, Outer Joins and on the // right hand side of Unions. - lazy val attributes = child.outputSet + lazy val outerAttributes = child.outputSet def failOnCorrelatedReference( - p: LogicalPlan, - message: String): Unit = p.transformAllExpressions { - case e: NamedExpression if attributes.contains(e) => - failAnalysis(s"Accessing outer query column is not allowed in $message: $e") + plan: LogicalPlan, + message: String): Unit = plan foreach { + case p => + lazy val inputs = p.inputSet + p.transformExpressions { + case e: AttributeReference + if !inputs.contains(e) && outerAttributes.contains(e) => + failAnalysis(s"Accessing outer query column is not allowed in $message: $e") + } } def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach { case a @ Aggregate(_, _, source) => - failOnCorrelatedReference(source, "an AGGREATE") + failOnCorrelatedReference(source, "an AGGREGATE") case j @ Join(left, _, RightOuter, _) => failOnCorrelatedReference(left, "a RIGHT OUTER JOIN") case j @ Join(_, right, jt, _) if jt != Inner => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 5323b79c57..0306afb0d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -376,6 +376,31 @@ object HiveTypeCoercion { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } + + case InSubQuery(struct: CreateStruct, subquery, exprId) + if struct.children.zip(subquery.output).exists(x => x._1.dataType != x._2.dataType) => + val widerTypes: Seq[Option[DataType]] = struct.children.zip(subquery.output).map { + case (l, r) => findWiderTypeForTwo(l.dataType, r.dataType) + } + val newStruct = struct.withNewChildren(struct.children.zip(widerTypes).map { + case (e, Some(t)) => Cast(e, t) + case (e, _) => e + }) + val newSubquery = Project(subquery.output.zip(widerTypes).map { + case (a, Some(t)) => Alias(Cast(a, t), a.toString)() + case (a, _) => a + }, subquery) + InSubQuery(newStruct, newSubquery, exprId) + + case sub @ InSubQuery(expr, subquery, exprId) + if expr.dataType != subquery.output.head.dataType => + findWiderTypeForTwo(expr.dataType, subquery.output.head.dataType) match { + case Some(t) => + val attr = subquery.output.head + val proj = Seq(Alias(Cast(attr, t), attr.toString)()) + InSubQuery(Cast(expr, t), Project(proj, subquery), exprId) + case _ => sub + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index cbee0e61f7..1993bd2587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -87,7 +87,6 @@ case class ScalarSubquery( */ abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate { override def nullable: Boolean = false - override def plan: LogicalPlan = SubqueryAlias(prettyName, query) } object PredicateSubquery { @@ -105,10 +104,14 @@ object PredicateSubquery { * FROM b) * }}} */ -case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery { +case class InSubQuery( + value: Expression, + query: LogicalPlan, + exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery { override def children: Seq[Expression] = value :: Nil override lazy val resolved: Boolean = value.resolved && query.resolved - override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan) + override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan, exprId) + override def plan: LogicalPlan = SubqueryAlias(s"subquery#${exprId.id}", query) /** * The unwrapped value side expressions. @@ -124,7 +127,7 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu override def checkInputDataTypes(): TypeCheckResult = { // Check the number of arguments. if (expressions.length != query.output.length) { - TypeCheckResult.TypeCheckFailure( + return TypeCheckResult.TypeCheckFailure( s"The number of fields in the value (${expressions.length}) does not match with " + s"the number of columns in the subquery (${query.output.length})") } @@ -132,14 +135,16 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu // Check the argument types. expressions.zip(query.output).zipWithIndex.foreach { case ((e, a), i) if e.dataType != a.dataType => - TypeCheckResult.TypeCheckFailure( - s"The data type of value[$i](${e.dataType}) does not match " + + return TypeCheckResult.TypeCheckFailure( + s"The data type of value[$i] (${e.dataType}) does not match " + s"subquery column '${a.name}' (${a.dataType}).") case _ => } TypeCheckResult.TypeCheckSuccess } + + override def toString: String = s"$value IN subquery#${exprId.id}" } /** @@ -153,7 +158,11 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu * WHERE b.id = a.id) * }}} */ -case class Exists(query: LogicalPlan) extends PredicateSubquery { +case class Exists( + query: LogicalPlan, + exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery { override def children: Seq[Expression] = Nil - override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan) + override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan, exprId) + override def plan: LogicalPlan = SubqueryAlias(toString, query) + override def toString: String = s"exists#${exprId.id}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e6d554565d..e974f69ef1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1474,7 +1474,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { case f @ Filter(cond, child) => // Find all correlated predicates. val (correlated, local) = splitConjunctivePredicates(cond).partition { e => - e.references.intersect(references).nonEmpty + (e.references -- child.outputSet).intersect(references).nonEmpty } // Rewrite the filter without the correlated predicates if any. correlated match { @@ -1515,10 +1515,34 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { */ private def pullOutCorrelatedPredicates( in: InSubQuery, - query: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + query: LogicalPlan): (LogicalPlan, LogicalPlan, Seq[Expression]) = { val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query) - val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled) - (resolved, conditions) + // Check whether there is some attributes have same exprId but come from different side + val outerAttributes = AttributeSet(in.expressions.flatMap(_.references)) + if (outerAttributes.intersect(resolved.outputSet).nonEmpty) { + val aliases = mutable.Map[Attribute, Alias]() + val exprs = in.expressions.map { expr => + expr transformUp { + case a: AttributeReference if resolved.outputSet.contains(a) => + val alias = Alias(a, a.toString)() + val attr = alias.toAttribute + aliases += attr -> alias + attr + } + } + val newP = Project(query.output ++ aliases.values, query) + val projection = resolved.output.map { + case a if outerAttributes.contains(a) => Alias(a, a.toString)() + case a => a + } + val subquery = Project(projection, resolved) + val conditions = joinCondition ++ exprs.zip(subquery.output).map(EqualTo.tupled) + (newP, subquery, conditions) + } else { + val conditions = + joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled) + (query, resolved, conditions) + } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -1534,17 +1558,22 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { - case (p, Exists(sub)) => + case (p, Exists(sub, _)) => val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) Join(p, resolved, LeftSemi, conditions.reduceOption(And)) - case (p, Not(Exists(sub))) => + case (p, Not(Exists(sub, _))) => val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) Join(p, resolved, LeftAnti, conditions.reduceOption(And)) case (p, in: InSubQuery) => - val (resolved, conditions) = pullOutCorrelatedPredicates(in, p) - Join(p, resolved, LeftSemi, conditions.reduceOption(And)) + val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p) + if (newP fastEquals p) { + Join(p, resolved, LeftSemi, conditions.reduceOption(And)) + } else { + Project(p.output, + Join(newP, resolved, LeftSemi, conditions.reduceOption(And))) + } case (p, Not(in: InSubQuery)) => - val (resolved, conditions) = pullOutCorrelatedPredicates(in, p) + val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p) // This is a NULL-aware (left) anti join (NAAJ). // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. @@ -1553,7 +1582,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS // if performance matters to you. - Join(p, resolved, LeftAnti, Option(Or(anyNull, condition))) + if (newP fastEquals p) { + Join(p, resolved, LeftAnti, Option(Or(anyNull, condition))) + } else { + Project(p.output, + Join(newP, resolved, LeftAnti, Option(Or(anyNull, condition)))) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 7191936699..f5439d70ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -35,6 +35,10 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { plan transformAllExpressions { case s: ScalarSubquery => ScalarSubquery(s.query, ExprId(0)) + case s: InSubQuery => + InSubQuery(s.value, s.query, ExprId(0)) + case e: Exists => + Exists(e.query, ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index d69ef08735..d182495757 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -125,21 +125,21 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("EXISTS predicate subquery") { checkAnswer( - sql("select * from l where exists(select * from r where l.a = r.c)"), + sql("select * from l where exists (select * from r where l.a = r.c)"), Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) checkAnswer( - sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"), + sql("select * from l where exists (select * from r where l.a = r.c) and l.a <= 2"), Row(2, 1.0) :: Row(2, 1.0) :: Nil) } test("NOT EXISTS predicate subquery") { checkAnswer( - sql("select * from l where not exists(select * from r where l.a = r.c)"), + sql("select * from l where not exists (select * from r where l.a = r.c)"), Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil) checkAnswer( - sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"), + sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)"), Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) } @@ -160,20 +160,20 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("NOT IN predicate subquery") { checkAnswer( - sql("select * from l where a not in(select c from r)"), + sql("select * from l where a not in (select c from r)"), Nil) checkAnswer( - sql("select * from l where a not in(select c from r where c is not null)"), + sql("select * from l where a not in (select c from r where c is not null)"), Row(1, 2.0) :: Row(1, 2.0) :: Nil) checkAnswer( - sql("select * from l where a not in(select c from t where b < d)"), + sql("select * from l where a not in (select c from t where b < d)"), Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil) // Empty sub-query checkAnswer( - sql("select * from l where a not in(select c from r where c > 10 and b < d)"), + sql("select * from l where a not in (select c from r where c > 10 and b < d)"), Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) @@ -181,11 +181,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("complex IN predicate subquery") { checkAnswer( - sql("select * from l where (a, b) not in(select c, d from r)"), + sql("select * from l where (a, b) not in (select c, d from r)"), Nil) checkAnswer( - sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"), + sql("select * from l where (a, b) not in (select c, d from t) and (a + b) is not null"), Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil) } + + test("same column in subquery and outer table") { + checkAnswer( + sql("select a from l l1 where a in (select a from l where a < 3 group by a)"), + Row(1) :: Row(1) :: Row(2) :: Row(2) :: Nil + ) + } } |