aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-04-22 20:55:41 +0200
committerHerman van Hovell <hvanhovell@questtec.nl>2016-04-22 20:55:41 +0200
commitc417cec067715455c1536d37e0dba97cc8657f7b (patch)
treea89a659b72023065297c284087c1d310577c5697
parentd060da098aa0449f519fb22c3ed8f75f87ba5f12 (diff)
downloadspark-c417cec067715455c1536d37e0dba97cc8657f7b.tar.gz
spark-c417cec067715455c1536d37e0dba97cc8657f7b.tar.bz2
spark-c417cec067715455c1536d37e0dba97cc8657f7b.zip
[SPARK-14763][SQL] fix subquery resolution
## What changes were proposed in this pull request? Currently, a column could be resolved wrongly if there are columns from both outer table and subquery have the same name, we should only resolve the attributes that can't be resolved within subquery. They may have same exprId than other attributes in subquery, so we should create alias for them. Also, the column in IN subquery could have same exprId, we should create alias for them. ## How was this patch tested? Added regression tests. Manually tests TPCDS Q70 and Q95, work well after this patch. Author: Davies Liu <davies@databricks.com> Closes #12539 from davies/fix_subquery.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala70
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala17
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala54
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala27
7 files changed, 173 insertions, 49 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8595762988..182e459d8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -862,28 +862,68 @@ class Analyzer(
object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
/**
- * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a
+ * Resolve the correlated predicates in the clauses (e.g. WHERE or HAVING) of a
* sub-query by using the plan the predicates should be correlated to.
*/
- private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = {
- q transformUp {
- case f @ Filter(cond, child) if child.resolved && !f.resolved =>
- val newCond = resolveExpression(cond, p, throws = false)
- if (!cond.fastEquals(newCond)) {
- Filter(newCond, child)
- } else {
- f
- }
+ private def resolveCorrelatedSubquery(
+ sub: LogicalPlan, outer: LogicalPlan,
+ aliases: scala.collection.mutable.Map[Attribute, Alias]): LogicalPlan = {
+ // First resolve as much of the sub-query as possible
+ val analyzed = execute(sub)
+ if (analyzed.resolved) {
+ analyzed
+ } else {
+ // Only resolve the lowest plan that is not resolved by outer plan, otherwise it could be
+ // resolved by itself
+ val resolvedByOuter = analyzed transformDown {
+ case q: LogicalPlan if q.childrenResolved && !q.resolved =>
+ q transformExpressions {
+ case u @ UnresolvedAttribute(nameParts) =>
+ withPosition(u) {
+ try {
+ val outerAttrOpt = outer.resolve(nameParts, resolver)
+ if (outerAttrOpt.isDefined) {
+ val outerAttr = outerAttrOpt.get
+ if (q.inputSet.contains(outerAttr)) {
+ // Got a conflict, create an alias for the attribute come from outer table
+ val alias = Alias(outerAttr, outerAttr.toString)()
+ val attr = alias.toAttribute
+ aliases += attr -> alias
+ attr
+ } else {
+ outerAttr
+ }
+ } else {
+ u
+ }
+ } catch {
+ case a: AnalysisException => u
+ }
+ }
+ }
+ }
+ if (resolvedByOuter fastEquals analyzed) {
+ analyzed
+ } else {
+ resolveCorrelatedSubquery(resolvedByOuter, outer, aliases)
+ }
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case q: LogicalPlan if q.childrenResolved =>
- q transformExpressions {
+ // Only a few unary node (Project/Filter/Aggregate/Having) could have subquery
+ case q: UnaryNode if q.childrenResolved =>
+ val aliases = scala.collection.mutable.Map[Attribute, Alias]()
+ val newPlan = q transformExpressions {
case e: SubqueryExpression if !e.query.resolved =>
- // First resolve as much of the sub-query as possible. After that we use the children of
- // this plan to resolve the remaining correlated predicates.
- e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates))
+ e.withNewPlan(resolveCorrelatedSubquery(e.query, q.child, aliases))
+ }
+ if (aliases.nonEmpty) {
+ val projs = q.child.output ++ aliases.values
+ Project(q.child.output,
+ newPlan.withNewChildren(Seq(Project(projs, q.child))))
+ } else {
+ newPlan
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 45e4d535c1..a50b9a1e1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -113,16 +113,21 @@ trait CheckAnalysis extends PredicateHelper {
case f @ Filter(condition, child) =>
// Make sure that no correlated reference is below Aggregates, Outer Joins and on the
// right hand side of Unions.
- lazy val attributes = child.outputSet
+ lazy val outerAttributes = child.outputSet
def failOnCorrelatedReference(
- p: LogicalPlan,
- message: String): Unit = p.transformAllExpressions {
- case e: NamedExpression if attributes.contains(e) =>
- failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+ plan: LogicalPlan,
+ message: String): Unit = plan foreach {
+ case p =>
+ lazy val inputs = p.inputSet
+ p.transformExpressions {
+ case e: AttributeReference
+ if !inputs.contains(e) && outerAttributes.contains(e) =>
+ failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+ }
}
def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach {
case a @ Aggregate(_, _, source) =>
- failOnCorrelatedReference(source, "an AGGREATE")
+ failOnCorrelatedReference(source, "an AGGREGATE")
case j @ Join(left, _, RightOuter, _) =>
failOnCorrelatedReference(left, "a RIGHT OUTER JOIN")
case j @ Join(_, right, jt, _) if jt != Inner =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5323b79c57..0306afb0d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -376,6 +376,31 @@ object HiveTypeCoercion {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
}
+
+ case InSubQuery(struct: CreateStruct, subquery, exprId)
+ if struct.children.zip(subquery.output).exists(x => x._1.dataType != x._2.dataType) =>
+ val widerTypes: Seq[Option[DataType]] = struct.children.zip(subquery.output).map {
+ case (l, r) => findWiderTypeForTwo(l.dataType, r.dataType)
+ }
+ val newStruct = struct.withNewChildren(struct.children.zip(widerTypes).map {
+ case (e, Some(t)) => Cast(e, t)
+ case (e, _) => e
+ })
+ val newSubquery = Project(subquery.output.zip(widerTypes).map {
+ case (a, Some(t)) => Alias(Cast(a, t), a.toString)()
+ case (a, _) => a
+ }, subquery)
+ InSubQuery(newStruct, newSubquery, exprId)
+
+ case sub @ InSubQuery(expr, subquery, exprId)
+ if expr.dataType != subquery.output.head.dataType =>
+ findWiderTypeForTwo(expr.dataType, subquery.output.head.dataType) match {
+ case Some(t) =>
+ val attr = subquery.output.head
+ val proj = Seq(Alias(Cast(attr, t), attr.toString)())
+ InSubQuery(Cast(expr, t), Project(proj, subquery), exprId)
+ case _ => sub
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index cbee0e61f7..1993bd2587 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -87,7 +87,6 @@ case class ScalarSubquery(
*/
abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate {
override def nullable: Boolean = false
- override def plan: LogicalPlan = SubqueryAlias(prettyName, query)
}
object PredicateSubquery {
@@ -105,10 +104,14 @@ object PredicateSubquery {
* FROM b)
* }}}
*/
-case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery {
+case class InSubQuery(
+ value: Expression,
+ query: LogicalPlan,
+ exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
override def children: Seq[Expression] = value :: Nil
override lazy val resolved: Boolean = value.resolved && query.resolved
- override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan)
+ override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan, exprId)
+ override def plan: LogicalPlan = SubqueryAlias(s"subquery#${exprId.id}", query)
/**
* The unwrapped value side expressions.
@@ -124,7 +127,7 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu
override def checkInputDataTypes(): TypeCheckResult = {
// Check the number of arguments.
if (expressions.length != query.output.length) {
- TypeCheckResult.TypeCheckFailure(
+ return TypeCheckResult.TypeCheckFailure(
s"The number of fields in the value (${expressions.length}) does not match with " +
s"the number of columns in the subquery (${query.output.length})")
}
@@ -132,14 +135,16 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu
// Check the argument types.
expressions.zip(query.output).zipWithIndex.foreach {
case ((e, a), i) if e.dataType != a.dataType =>
- TypeCheckResult.TypeCheckFailure(
- s"The data type of value[$i](${e.dataType}) does not match " +
+ return TypeCheckResult.TypeCheckFailure(
+ s"The data type of value[$i] (${e.dataType}) does not match " +
s"subquery column '${a.name}' (${a.dataType}).")
case _ =>
}
TypeCheckResult.TypeCheckSuccess
}
+
+ override def toString: String = s"$value IN subquery#${exprId.id}"
}
/**
@@ -153,7 +158,11 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu
* WHERE b.id = a.id)
* }}}
*/
-case class Exists(query: LogicalPlan) extends PredicateSubquery {
+case class Exists(
+ query: LogicalPlan,
+ exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
override def children: Seq[Expression] = Nil
- override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan)
+ override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan, exprId)
+ override def plan: LogicalPlan = SubqueryAlias(toString, query)
+ override def toString: String = s"exists#${exprId.id}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e6d554565d..e974f69ef1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1474,7 +1474,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
case f @ Filter(cond, child) =>
// Find all correlated predicates.
val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
- e.references.intersect(references).nonEmpty
+ (e.references -- child.outputSet).intersect(references).nonEmpty
}
// Rewrite the filter without the correlated predicates if any.
correlated match {
@@ -1515,10 +1515,34 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
*/
private def pullOutCorrelatedPredicates(
in: InSubQuery,
- query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+ query: LogicalPlan): (LogicalPlan, LogicalPlan, Seq[Expression]) = {
val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
- val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
- (resolved, conditions)
+ // Check whether there is some attributes have same exprId but come from different side
+ val outerAttributes = AttributeSet(in.expressions.flatMap(_.references))
+ if (outerAttributes.intersect(resolved.outputSet).nonEmpty) {
+ val aliases = mutable.Map[Attribute, Alias]()
+ val exprs = in.expressions.map { expr =>
+ expr transformUp {
+ case a: AttributeReference if resolved.outputSet.contains(a) =>
+ val alias = Alias(a, a.toString)()
+ val attr = alias.toAttribute
+ aliases += attr -> alias
+ attr
+ }
+ }
+ val newP = Project(query.output ++ aliases.values, query)
+ val projection = resolved.output.map {
+ case a if outerAttributes.contains(a) => Alias(a, a.toString)()
+ case a => a
+ }
+ val subquery = Project(projection, resolved)
+ val conditions = joinCondition ++ exprs.zip(subquery.output).map(EqualTo.tupled)
+ (newP, subquery, conditions)
+ } else {
+ val conditions =
+ joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
+ (query, resolved, conditions)
+ }
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -1534,17 +1558,22 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
- case (p, Exists(sub)) =>
+ case (p, Exists(sub, _)) =>
val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
Join(p, resolved, LeftSemi, conditions.reduceOption(And))
- case (p, Not(Exists(sub))) =>
+ case (p, Not(Exists(sub, _))) =>
val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
Join(p, resolved, LeftAnti, conditions.reduceOption(And))
case (p, in: InSubQuery) =>
- val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
- Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+ val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+ if (newP fastEquals p) {
+ Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+ } else {
+ Project(p.output,
+ Join(newP, resolved, LeftSemi, conditions.reduceOption(And)))
+ }
case (p, Not(in: InSubQuery)) =>
- val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+ val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
// This is a NULL-aware (left) anti join (NAAJ).
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
@@ -1553,7 +1582,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS
// if performance matters to you.
- Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+ if (newP fastEquals p) {
+ Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+ } else {
+ Project(p.output,
+ Join(newP, resolved, LeftAnti, Option(Or(anyNull, condition))))
+ }
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 7191936699..f5439d70ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -35,6 +35,10 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
plan transformAllExpressions {
case s: ScalarSubquery =>
ScalarSubquery(s.query, ExprId(0))
+ case s: InSubQuery =>
+ InSubQuery(s.value, s.query, ExprId(0))
+ case e: Exists =>
+ Exists(e.query, ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index d69ef08735..d182495757 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -125,21 +125,21 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("EXISTS predicate subquery") {
checkAnswer(
- sql("select * from l where exists(select * from r where l.a = r.c)"),
+ sql("select * from l where exists (select * from r where l.a = r.c)"),
Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
checkAnswer(
- sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"),
+ sql("select * from l where exists (select * from r where l.a = r.c) and l.a <= 2"),
Row(2, 1.0) :: Row(2, 1.0) :: Nil)
}
test("NOT EXISTS predicate subquery") {
checkAnswer(
- sql("select * from l where not exists(select * from r where l.a = r.c)"),
+ sql("select * from l where not exists (select * from r where l.a = r.c)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil)
checkAnswer(
- sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"),
+ sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) ::
Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
}
@@ -160,20 +160,20 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("NOT IN predicate subquery") {
checkAnswer(
- sql("select * from l where a not in(select c from r)"),
+ sql("select * from l where a not in (select c from r)"),
Nil)
checkAnswer(
- sql("select * from l where a not in(select c from r where c is not null)"),
+ sql("select * from l where a not in (select c from r where c is not null)"),
Row(1, 2.0) :: Row(1, 2.0) :: Nil)
checkAnswer(
- sql("select * from l where a not in(select c from t where b < d)"),
+ sql("select * from l where a not in (select c from t where b < d)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil)
// Empty sub-query
checkAnswer(
- sql("select * from l where a not in(select c from r where c > 10 and b < d)"),
+ sql("select * from l where a not in (select c from r where c > 10 and b < d)"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) ::
Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
@@ -181,11 +181,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
test("complex IN predicate subquery") {
checkAnswer(
- sql("select * from l where (a, b) not in(select c, d from r)"),
+ sql("select * from l where (a, b) not in (select c, d from r)"),
Nil)
checkAnswer(
- sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"),
+ sql("select * from l where (a, b) not in (select c, d from t) and (a + b) is not null"),
Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil)
}
+
+ test("same column in subquery and outer table") {
+ checkAnswer(
+ sql("select a from l l1 where a in (select a from l where a < 3 group by a)"),
+ Row(1) :: Row(1) :: Row(2) :: Row(2) :: Nil
+ )
+ }
}