aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala70
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala17
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala54
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala27
7 files changed, 173 insertions, 49 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8595762988..182e459d8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -862,28 +862,68 @@ class Analyzer(
object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
/**
- * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a
+ * Resolve the correlated predicates in the clauses (e.g. WHERE or HAVING) of a
* sub-query by using the plan the predicates should be correlated to.
*/
- private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = {
- q transformUp {
- case f @ Filter(cond, child) if child.resolved && !f.resolved =>
- val newCond = resolveExpression(cond, p, throws = false)
- if (!cond.fastEquals(newCond)) {
- Filter(newCond, child)
- } else {
- f
- }
+ private def resolveCorrelatedSubquery(
+ sub: LogicalPlan, outer: LogicalPlan,
+ aliases: scala.collection.mutable.Map[Attribute, Alias]): LogicalPlan = {
+ // First resolve as much of the sub-query as possible
+ val analyzed = execute(sub)
+ if (analyzed.resolved) {
+ analyzed
+ } else {
+ // Only resolve the lowest plan that is not resolved by outer plan, otherwise it could be
+ // resolved by itself
+ val resolvedByOuter = analyzed transformDown {
+ case q: LogicalPlan if q.childrenResolved && !q.resolved =>
+ q transformExpressions {
+ case u @ UnresolvedAttribute(nameParts) =>
+ withPosition(u) {
+ try {
+ val outerAttrOpt = outer.resolve(nameParts, resolver)
+ if (outerAttrOpt.isDefined) {
+ val outerAttr = outerAttrOpt.get
+ if (q.inputSet.contains(outerAttr)) {
+ // Got a conflict, create an alias for the attribute come from outer table
+ val alias = Alias(outerAttr, outerAttr.toString)()
+ val attr = alias.toAttribute
+ aliases += attr -> alias
+ attr
+ } else {
+ outerAttr
+ }
+ } else {
+ u
+ }
+ } catch {
+ case a: AnalysisException => u
+ }
+ }
+ }
+ }
+ if (resolvedByOuter fastEquals analyzed) {
+ analyzed
+ } else {
+ resolveCorrelatedSubquery(resolvedByOuter, outer, aliases)
+ }
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case q: LogicalPlan if q.childrenResolved =>
- q transformExpressions {
+ // Only a few unary node (Project/Filter/Aggregate/Having) could have subquery
+ case q: UnaryNode if q.childrenResolved =>
+ val aliases = scala.collection.mutable.Map[Attribute, Alias]()
+ val newPlan = q transformExpressions {
case e: SubqueryExpression if !e.query.resolved =>
- // First resolve as much of the sub-query as possible. After that we use the children of
- // this plan to resolve the remaining correlated predicates.
- e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates))
+ e.withNewPlan(resolveCorrelatedSubquery(e.query, q.child, aliases))
+ }
+ if (aliases.nonEmpty) {
+ val projs = q.child.output ++ aliases.values
+ Project(q.child.output,
+ newPlan.withNewChildren(Seq(Project(projs, q.child))))
+ } else {
+ newPlan
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 45e4d535c1..a50b9a1e1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -113,16 +113,21 @@ trait CheckAnalysis extends PredicateHelper {
case f @ Filter(condition, child) =>
// Make sure that no correlated reference is below Aggregates, Outer Joins and on the
// right hand side of Unions.
- lazy val attributes = child.outputSet
+ lazy val outerAttributes = child.outputSet
def failOnCorrelatedReference(
- p: LogicalPlan,
- message: String): Unit = p.transformAllExpressions {
- case e: NamedExpression if attributes.contains(e) =>
- failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+ plan: LogicalPlan,
+ message: String): Unit = plan foreach {
+ case p =>
+ lazy val inputs = p.inputSet
+ p.transformExpressions {
+ case e: AttributeReference
+ if !inputs.contains(e) && outerAttributes.contains(e) =>
+ failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+ }
}
def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach {
case a @ Aggregate(_, _, source) =>
- failOnCorrelatedReference(source, "an AGGREATE")
+ failOnCorrelatedReference(source, "an AGGREGATE")
case j @ Join(left, _, RightOuter, _) =>
failOnCorrelatedReference(left, "a RIGHT OUTER JOIN")
case j @ Join(_, right, jt, _) if jt != Inner =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5323b79c57..0306afb0d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -376,6 +376,31 @@ object HiveTypeCoercion {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
}
+
+ case InSubQuery(struct: CreateStruct, subquery, exprId)
+ if struct.children.zip(subquery.output).exists(x => x._1.dataType != x._2.dataType) =>
+ val widerTypes: Seq[Option[DataType]] = struct.children.zip(subquery.output).map {
+ case (l, r) => findWiderTypeForTwo(l.dataType, r.dataType)
+ }
+ val newStruct = struct.withNewChildren(struct.children.zip(widerTypes).map {
+ case (e, Some(t)) => Cast(e, t)
+ case (e, _) => e
+ })
+ val newSubquery = Project(subquery.output.zip(widerTypes).map {
+ case (a, Some(t)) => Alias(Cast(a, t), a.toString)()
+ case (a, _) => a
+ }, subquery)
+ InSubQuery(newStruct, newSubquery, exprId)
+
+ case sub @ InSubQuery(expr, subquery, exprId)
+ if expr.dataType != subquery.output.head.dataType =>
+ findWiderTypeForTwo(expr.dataType, subquery.output.head.dataType) match {
+ case Some(t) =>
+ val attr = subquery.output.head
+ val proj = Seq(Alias(Cast(attr, t), attr.toString)())
+ InSubQuery(Cast(expr, t), Project(proj, subquery), exprId)
+ case _ => sub
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index cbee0e61f7..1993bd2587 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -87,7 +87,6 @@ case class ScalarSubquery(
*/
abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate {
override def nullable: Boolean = false
- override def plan: LogicalPlan = SubqueryAlias(prettyName, query)
}
object PredicateSubquery {
@@ -105,10 +104,14 @@ object PredicateSubquery {
* FROM b)
* }}}
*/
-case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery {
+case class InSubQuery(
+ value: Expression,
+ query: LogicalPlan,
+ exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
override def children: Seq[Expression] = value :: Nil
override lazy val resolved: Boolean = value.resolved && query.resolved
- override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan)
+ override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan, exprId)
+ override def plan: LogicalPlan = SubqueryAlias(s"subquery#${exprId.id}", query)
/**
* The unwrapped value side expressions.
@@ -124,7 +127,7 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu
override def checkInputDataTypes(): TypeCheckResult = {
// Check the number of arguments.
if (expressions.length != query.output.length) {
- TypeCheckResult.TypeCheckFailure(
+ return TypeCheckResult.TypeCheckFailure(
s"The number of fields in the value (${expressions.length}) does not match with " +
s"the number of columns in the subquery (${query.output.length})")
}
@@ -132,14 +135,16 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu
// Check the argument types.
expressions.zip(query.output).zipWithIndex.foreach {
case ((e, a), i) if e.dataType != a.dataType =>
- TypeCheckResult.TypeCheckFailure(
- s"The data type of value[$i](${e.dataType}) does not match " +
+ return TypeCheckResult.TypeCheckFailure(
+ s"The data type of value[$i] (${e.dataType}) does not match " +
s"subquery column '${a.name}' (${a.dataType}).")
case _ =>
}
TypeCheckResult.TypeCheckSuccess
}
+
+ override def toString: String = s"$value IN subquery#${exprId.id}"
}
/**
@@ -153,7 +158,11 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu
* WHERE b.id = a.id)
* }}}
*/
-case class Exists(query: LogicalPlan) extends PredicateSubquery {
+case class Exists(
+ query: LogicalPlan,
+ exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
override def children: Seq[Expression] = Nil
- override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan)
+ override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan, exprId)
+ override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def toString: String = s"exists#${exprId.id}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e6d554565d..e974f69ef1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1474,7 +1474,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
case f @ Filter(cond, child) =>
// Find all correlated predicates.
val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
- e.references.intersect(references).nonEmpty
+ (e.references -- child.outputSet).intersect(references).nonEmpty
}
// Rewrite the filter without the correlated predicates if any.
correlated match {
@@ -1515,10 +1515,34 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
*/
private def pullOutCorrelatedPredicates(
in: InSubQuery,
- query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+ query: LogicalPlan): (LogicalPlan, LogicalPlan, Seq[Expression]) = {
val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
- val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
- (resolved, conditions)
+ // Check whether there is some attributes have same exprId but come from different side
+ val outerAttributes = AttributeSet(in.expressions.flatMap(_.references))
+ if (outerAttributes.intersect(resolved.outputSet).nonEmpty) {
+ val aliases = mutable.Map[Attribute, Alias]()
+ val exprs = in.expressions.map { expr =>
+ expr transformUp {
+ case a: AttributeReference if resolved.outputSet.contains(a) =>
+ val alias = Alias(a, a.toString)()
+ val attr = alias.toAttribute
+ aliases += attr -> alias
+ attr
+ }
+ }
+ val newP = Project(query.output ++ aliases.values, query)
+ val projection = resolved.output.map {
+ case a if outerAttributes.contains(a) => Alias(a, a.toString)()
+ case a => a
+ }
+ val subquery = Project(projection, resolved)
+ val conditions = joinCondition ++ exprs.zip(subquery.output).map(EqualTo.tupled)
+ (newP, subquery, conditions)
+ } else {
+ val conditions =
+ joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
+ (query, resolved, conditions)
+ }
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -1534,17 +1558,22 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
- case (p, Exists(sub)) =>
+ case (p, Exists(sub, _)) =>
val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
Join(p, resolved, LeftSemi, conditions.reduceOption(And))
- case (p, Not(Exists(sub))) =>
+ case (p, Not(Exists(sub, _))) =>
val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
Join(p, resolved, LeftAnti, conditions.reduceOption(And))
case (p, in: InSubQuery) =>
- val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
- Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+ val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+ if (newP fastEquals p) {
+ Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+ } else {
+ Project(p.output,
+ Join(newP, resolved, LeftSemi, conditions.reduceOption(And)))
+ }
case (p, Not(in: InSubQuery)) =>
- val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+ val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
// This is a NULL-aware (left) anti join (NAAJ).
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
@@ -1553,7 +1582,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS
// if performance matters to you.
- Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+ if (newP fastEquals p) {
+ Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+ } else {
+ Project(p.output,
+ Join(newP, resolved, LeftAnti, Option(Or(anyNull, condition))))
+ }
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 7191936699..f5439d70ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -35,6 +35,10 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
plan transformAllExpressions {
case s: ScalarSubquery =>
ScalarSubquery(s.query, ExprId(0))
+ case s: InSubQuery =>
+ InSubQuery(s.value, s.query, ExprId(0))
+ case e: Exists =>
+ Exists(e.query, ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index d69ef08735..d182495757 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -125,21 +125,21 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("EXISTS predicate subquery") {
checkAnswer(
- sql("select * from l where exists(select * from r where l.a = r.c)"),
+ sql("select * from l where exists (select * from r where l.a = r.c)"),
Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
checkAnswer(
- sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"),
+ sql("select * from l where exists (select * from r where l.a = r.c) and l.a <= 2"),
Row(2, 1.0) :: Row(2, 1.0) :: Nil)
}
test("NOT EXISTS predicate subquery") {
checkAnswer(
- sql("select * from l where not exists(select * from r where l.a = r.c)"),
+ sql("select * from l where not exists (select * from r where l.a = r.c)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil)
checkAnswer(
- sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"),
+ sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) ::
Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
}
@@ -160,20 +160,20 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("NOT IN predicate subquery") {
checkAnswer(
- sql("select * from l where a not in(select c from r)"),
+ sql("select * from l where a not in (select c from r)"),
Nil)
checkAnswer(
- sql("select * from l where a not in(select c from r where c is not null)"),
+ sql("select * from l where a not in (select c from r where c is not null)"),
Row(1, 2.0) :: Row(1, 2.0) :: Nil)
checkAnswer(
- sql("select * from l where a not in(select c from t where b < d)"),
+ sql("select * from l where a not in (select c from t where b < d)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil)
// Empty sub-query
checkAnswer(
- sql("select * from l where a not in(select c from r where c > 10 and b < d)"),
+ sql("select * from l where a not in (select c from r where c > 10 and b < d)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) ::
Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
@@ -181,11 +181,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("complex IN predicate subquery") {
checkAnswer(
- sql("select * from l where (a, b) not in(select c, d from r)"),
+ sql("select * from l where (a, b) not in (select c, d from r)"),
Nil)
checkAnswer(
- sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"),
+ sql("select * from l where (a, b) not in (select c, d from t) and (a + b) is not null"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil)
}
+
+ test("same column in subquery and outer table") {
+ checkAnswer(
+ sql("select a from l l1 where a in (select a from l where a < 3 group by a)"),
+ Row(1) :: Row(1) :: Row(2) :: Row(2) :: Nil
+ )
+ }
}