From da8859226e09aa6ebcf6a1c5c1369dec3c216eac Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 19 Apr 2016 15:16:02 -0700 Subject: [SPARK-4226] [SQL] Support IN/EXISTS Subqueries ### What changes were proposed in this pull request? This PR adds support for in/exists predicate subqueries to Spark. Predicate sub-queries are used as a filtering condition in a query (this is the only supported use case). A predicate sub-query comes in two forms: - `[NOT] EXISTS(subquery)` - `[NOT] IN (subquery)` This PR is (loosely) based on the work of davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/9055). They should be credited for the work they did. ### How was this patch tested? Modified parsing unit tests. Added tests to `org.apache.spark.sql.SQLQuerySuite` cc rxin, davies & chenghao-intel Author: Herman van Hovell Closes #12306 from hvanhovell/SPARK-4226. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 30 ++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 40 ++++++- .../spark/sql/catalyst/expressions/subquery.scala | 84 ++++++++++++++- .../spark/sql/catalyst/optimizer/Optimizer.scala | 115 ++++++++++++++++++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 16 +-- .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 58 ++++++++++- .../sql/catalyst/parser/ErrorParserSuite.scala | 6 +- .../catalyst/parser/ExpressionParserSuite.scala | 8 +- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +- .../org/apache/spark/sql/execution/subquery.scala | 6 +- .../scala/org/apache/spark/sql/QueryTest.scala | 53 ++++++++-- .../scala/org/apache/spark/sql/SubquerySuite.scala | 98 ++++++++++++++++++ 12 files changed, 476 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0e2fd43983..236476900a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -855,25 +855,35 @@ class Analyzer( } /** - * This rule resolve subqueries inside expressions. + * This rule resolves sub-queries inside expressions. * - * Note: CTE are handled in CTESubstitution. + * Note: CTEs are handled in CTESubstitution. */ object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def hasSubquery(e: Expression): Boolean = { - e.find(_.isInstanceOf[SubqueryExpression]).isDefined - } - - private def hasSubquery(q: LogicalPlan): Boolean = { - q.expressions.exists(hasSubquery) + /** + * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a + * sub-query by using the plan the predicates should be correlated to. + */ + private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = { + q transformUp { + case f @ Filter(cond, child) if child.resolved && !f.resolved => + val newCond = resolveExpression(cond, p, throws = false) + if (!cond.fastEquals(newCond)) { + Filter(newCond, child) + } else { + f + } + } } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case q: LogicalPlan if q.childrenResolved && hasSubquery(q) => + case q: LogicalPlan if q.childrenResolved => q transformExpressions { case e: SubqueryExpression if !e.query.resolved => - e.withNewPlan(execute(e.query)) + // First resolve as much of the sub-query as possible. 
After that we use the children of + // this plan to resolve the remaining correlated predicates. + e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d6a8c3eec8..45e4d535c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.UsingJoin +import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ /** * Throws user facing errors when passed invalid queries that fail to analyze. */ -trait CheckAnalysis { +trait CheckAnalysis extends PredicateHelper { /** * Override to provide additional checks for correct analysis. @@ -110,6 +110,39 @@ trait CheckAnalysis { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case f @ Filter(condition, child) => + // Make sure that no correlated reference is below Aggregates, Outer Joins, or on the + // right-hand side of Unions. + lazy val attributes = child.outputSet + def failOnCorrelatedReference( + p: LogicalPlan, + message: String): Unit = p.transformAllExpressions { + case e: NamedExpression if attributes.contains(e) => + failAnalysis(s"Accessing outer query column is not allowed in $message: $e") + } + def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach { + case a @ Aggregate(_, _, source) => + failOnCorrelatedReference(source, "an AGGREGATE") + case j @ Join(left, _, RightOuter, _) => + failOnCorrelatedReference(left, "a RIGHT OUTER JOIN") + case j @ Join(_, right, jt, _) if jt != Inner => + failOnCorrelatedReference(right, "a LEFT (OUTER) JOIN") + case Union(_ :: xs) => + xs.foreach(failOnCorrelatedReference(_, "a UNION")) + case s: SetOperation => + failOnCorrelatedReference(s.right, "an INTERSECT/EXCEPT") + case _ => + } + splitConjunctivePredicates(condition).foreach { + case p: PredicateSubquery => + checkForCorrelatedReferences(p) + case Not(p: PredicateSubquery) => + checkForCorrelatedReferences(p) + case e if PredicateSubquery.hasPredicateSubquery(e) => + failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e") + case e => + } + case j @ Join(_, _, UsingJoin(_, cols), _) => val from = operator.inputSet.map(_.name).mkString(", ") failAnalysis( @@ -209,6 +242,9 @@ trait CheckAnalysis { | but one table has '${firstError.output.length}' columns and another table has | '${s.children.head.output.length}' columns""".stripMargin) + case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => + failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 968bbdb1a5..cbee0e61f7 100644 ---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ /** * An interface for subquery that is used in expressions. */ -abstract class SubqueryExpression extends LeafExpression { +abstract class SubqueryExpression extends Expression { /** * The logical plan of the query. @@ -61,6 +61,8 @@ case class ScalarSubquery( override def dataType: DataType = query.schema.fields.head.dataType + override def children: Seq[Expression] = Nil + override def checkInputDataTypes(): TypeCheckResult = { if (query.schema.length != 1) { TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + @@ -77,3 +79,81 @@ case class ScalarSubquery( override def toString: String = s"subquery#${exprId.id}" } + +/** + * A predicate subquery checks the existence of a value in a sub-query. We currently only allow + * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will + * be rewritten into a left semi/anti join during analysis. + */ +abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate { + override def nullable: Boolean = false + override def plan: LogicalPlan = SubqueryAlias(prettyName, query) +} + +object PredicateSubquery { + def hasPredicateSubquery(e: Expression): Boolean = { + e.find(_.isInstanceOf[PredicateSubquery]).isDefined + } +} + +/** + * The [[InSubQuery]] predicate checks the existence of a value in a sub-query. For example (SQL): + * {{{ + * SELECT * + * FROM a + * WHERE a.id IN (SELECT id + * FROM b) + * }}} + */ +case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery { + override def children: Seq[Expression] = value :: Nil + override lazy val resolved: Boolean = value.resolved && query.resolved + override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan) + + /** + * The unwrapped value side expressions. + */ + lazy val expressions: Seq[Expression] = value match { + case CreateStruct(cols) => cols + case col => Seq(col) + } + + /** + * Check if the number of columns and the data types on both sides match. + */ + override def checkInputDataTypes(): TypeCheckResult = { + // Check the number of arguments. + if (expressions.length != query.output.length) { + TypeCheckResult.TypeCheckFailure( + s"The number of fields in the value (${expressions.length}) does not match with " + + s"the number of columns in the subquery (${query.output.length})") + } else { + // Check the argument types and report the first mismatching pair, if any. + expressions.zip(query.output).zipWithIndex.collectFirst { + case ((e, a), i) if e.dataType != a.dataType => + TypeCheckResult.TypeCheckFailure( + s"The data type of value[$i] (${e.dataType}) does not match " + + s"subquery column '${a.name}' (${a.dataType}).") + }.getOrElse(TypeCheckResult.TypeCheckSuccess) + } + } +} + +/** + * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition.
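+ * It is rewritten into a left semi join (or a left anti join for NOT EXISTS) by the + * [[RewritePredicateSubquery]] rule.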
+ * For example (SQL): + * {{{ + * SELECT * + * FROM a + * WHERE EXISTS (SELECT * + * FROM b + * WHERE b.id = a.id) + * }}} + */ +case class Exists(query: LogicalPlan) extends PredicateSubquery { + override def children: Seq[Expression] = Nil + override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0a5232b2d4..ecc2d773e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import scala.collection.mutable import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{InSubQuery, _} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} @@ -47,6 +48,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // However, because we also use the analyzer to canonicalized queries (for view definition), // we do not eliminate subqueries or compute current time in the analyzer. Batch("Finish Analysis", Once, + RewritePredicateSubquery, EliminateSubqueryAliases, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), @@ -1446,3 +1448,114 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] { } } } + +/** + * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates + * are supported: + * a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter + * will be pulled out as the join conditions. + * b. IN/NOT IN will be rewritten as semi/anti join, unresolved conditions in the Filter will + * be pulled out as join conditions, value = selected column will also be used as join + * condition. + */ +object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { + /** + * Pull out all correlated predicates from a given sub-query. This method removes the correlated + * predicates from sub-query [[Filter]]s and adds the references of these predicates to + * all intermediate [[Project]] clauses (if they are missing) in order to be able to evaluate the + * predicates in the join condition. + * + * This method returns the rewritten sub-query and the combined (AND) extracted predicate. + */ + private def pullOutCorrelatedPredicates( + subquery: LogicalPlan, + query: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + val references = query.outputSet + val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]] + val transformed = subquery transformUp { + case f @ Filter(cond, child) => + // Find all correlated predicates. + val (correlated, local) = splitConjunctivePredicates(cond).partition { e => + e.references.intersect(references).nonEmpty + } + // Rewrite the filter without the correlated predicates if any. 
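+ // For example (illustrative): given an outer plan over [a] and a sub-query filter + // (b.id = a.id) AND (b.v > 0), the correlated conjunct (b.id = a.id) is recorded in + // predicateMap, while the local conjunct (b.v > 0) stays behind in a rewritten Filter.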
+ correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> correlated + newFilter + case xs => + predicateMap += child -> correlated + child + } + case p @ Project(expressions, child) => + // Find all pulled out predicates defined in the Project's subtree. + val localPredicates = p.collect(predicateMap).flatten + + // Determine which correlated predicate references are missing from this project. + val localPredicateReferences = localPredicates + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + val missingReferences = localPredicateReferences -- p.references -- query.outputSet + + // Create a new project if we need to add missing references. + if (missingReferences.nonEmpty) { + Project(expressions ++ missingReferences, child) + } else { + p + } + } + (transformed, predicateMap.values.flatten.toSeq) + } + + /** + * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by + * constructing the required join condition. Both the rewritten subquery and the constructed + * join condition are returned. + */ + private def pullOutCorrelatedPredicates( + in: InSubQuery, + query: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query) + val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled) + (resolved, conditions) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(condition, child) => + val (withSubquery, withoutSubquery) = + splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) + + // Construct the pruned filter condition. + val newFilter: LogicalPlan = withoutSubquery match { + case Nil => child + case conditions => Filter(conditions.reduce(And), child) + } + + // Filter the plan by applying left semi and left anti joins. + withSubquery.foldLeft(newFilter) { + case (p, Exists(sub)) => + val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) + Join(p, resolved, LeftSemi, conditions.reduceOption(And)) + case (p, Not(Exists(sub))) => + val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) + Join(p, resolved, LeftAnti, conditions.reduceOption(And)) + case (p, in: InSubQuery) => + val (resolved, conditions) = pullOutCorrelatedPredicates(in, p) + Join(p, resolved, LeftSemi, conditions.reduceOption(And)) + case (p, Not(in: InSubQuery)) => + val (resolved, conditions) = pullOutCorrelatedPredicates(in, p) + // This is a NULL-aware (left) anti join (NAAJ). + // Construct the condition. A NULL in one of the conditions is regarded as a positive + // result; such a row will be filtered out by the Anti-Join operator. + val anyNull = conditions.map(IsNull).reduceLeft(Or) + val condition = conditions.reduceLeft(And) + + // Note that this will almost certainly be planned as a Broadcast Nested Loop join. Use + // EXISTS if performance matters to you.
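+ // For example (illustrative), with conditions (a = c) and (b = d) the anti-join condition + // becomes: (isnull(a = c) OR isnull(b = d)) OR ((a = c) AND (b = d)).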
+ Join(p, resolved, LeftAnti, Option(Or(anyNull, condition))) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index aa59f3fb2a..1c067621df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -391,9 +391,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Having val withHaving = withProject.optional(having) { - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(expression(having), BooleanType), withProject) + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(having) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, withProject) } // Distinct @@ -866,10 +870,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a filtering correlated sub-query. This is not supported yet. + * Create a filtering correlated sub-query (EXISTS). */ override def visitExists(ctx: ExistsContext): Expression = { - throw new ParseException("EXISTS clauses are not supported.", ctx) + Exists(plan(ctx.query)) } /** @@ -944,7 +948,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + invertIfNotDefined(InSubQuery(e, plan(ctx.query))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ad101d1c40..a90636d278 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -444,4 +444,60 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) } + + test("PredicateSubQuery is used outside of a filter") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val plan = Project( + Seq(a, Alias(InSubQuery(a, LocalRelation(b)), "c")()), + LocalRelation(a)) + assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) + } + + test("PredicateSubQuery is used in
a nested condition") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", BooleanType)() + val plan1 = Filter(Cast(InSubQuery(a, LocalRelation(b)), BooleanType), LocalRelation(a)) + assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + + val plan2 = Filter(Or(InSubQuery(a, LocalRelation(b)), c), LocalRelation(a, c)) + assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + } + + test("PredicateSubQuery correlated predicate is nested in an illegal plan") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", IntegerType)() + + val plan1 = Filter( + Exists( + Join( + LocalRelation(b), + Filter(EqualTo(a, c), LocalRelation(c)), + LeftOuter, + Option(EqualTo(b, c)))), + LocalRelation(a)) + assertAnalysisError(plan1, "Accessing outer query column is not allowed in" :: Nil) + + val plan2 = Filter( + Exists( + Join( + Filter(EqualTo(a, c), LocalRelation(c)), + LocalRelation(b), + RightOuter, + Option(EqualTo(b, c)))), + LocalRelation(a)) + assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) + + val plan3 = Filter( + Exists(Aggregate(Seq.empty, Seq.empty, Filter(EqualTo(a, c), LocalRelation(c)))), + LocalRelation(a)) + assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) + + val plan4 = Filter( + Exists(Union(LocalRelation(b), Filter(EqualTo(a, c), LocalRelation(c)))), + LocalRelation(a)) + assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index db96bfb652..6da3eaea3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -60,8 +60,8 @@ class ErrorParserSuite extends SparkFunSuite { intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", "^^^") - intercept("select * from r where a in (select * from t)", 1, 24, - "IN with a Sub-query is currently not supported", - "------------------------^^^") + intercept("select * from r except all select * from t", 1, 0, + "EXCEPT ALL is not supported", + "^^^") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 6f40ec67ec..d1dc8d621f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -113,7 +113,9 @@ class ExpressionParserSuite extends PlanTest { } test("exists expression") { - intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + assertEqual( + "exists (select 1 from b where b.x = a.x)", + Exists(table("b").where(Symbol("b.x") === Symbol("a.x")).select(1))) } test("comparison expressions") { @@ -139,7 +141,9 @@ class ExpressionParserSuite extends PlanTest { } test("in sub-query") { - intercept("a in (select b from c)", "IN with a Sub-query is currently not 
supported") + assertEqual( + "a in (select b from c)", + InSubQuery('a, table("c").select('b))) } test("like expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411e2372f2..a1ca55c262 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -107,7 +107,7 @@ class PlanParserSuite extends PlanTest { assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) assertEqual( "select a, b from db.c having x < 1", - table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + table("db", "c").select('a, 'b).where('x < 1)) assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) } @@ -405,7 +405,7 @@ class PlanParserSuite extends PlanTest { "select g from t group by g having a > (select b from s)", table("t") .groupBy('g)('g) - .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + .where('a > ScalarSubquery(table("s").select('b)))) } test("table reference") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index b3e8b37a2e..71b6a97852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -42,6 +43,7 @@ case class ScalarSubquery( override def plan: SparkPlan = Subquery(simpleString, executedPlan) override def dataType: DataType = executedPlan.schema.fields.head.dataType + override def children: Seq[Expression] = Nil override def nullable: Boolean = true override def toString: String = s"subquery#${exprId.id}" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 2dca792c83..cbacb5e103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Locale, TimeZone} +import java.util.{ArrayDeque, Locale, TimeZone} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -35,6 +35,8 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan import org.apache.spark.sql.types.ObjectType + + abstract class QueryTest extends PlanTest { protected def sqlContext: SQLContext @@ -47,6 +49,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains 
all of the keywords, or the * none of keywords are listed in the answer + * * @param df the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output @@ -119,6 +122,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. + * * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ @@ -158,6 +162,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer is within absTol of the expected result. + * * @param dataFrame the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. * @param absTol the absolute tolerance between actual and expected answers. @@ -198,7 +203,10 @@ abstract class QueryTest extends PlanTest { } private def checkJsonFormat(df: DataFrame): Unit = { + // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that + // RDD and Data resolution does not break. val logicalPlan = df.queryExecution.analyzed + // bypass some cases that we can't handle currently. logicalPlan.transform { case _: ObjectConsumer => return @@ -236,9 +244,27 @@ abstract class QueryTest extends PlanTest { // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains // these non-serializable stuff, and use these original ones to replace the null-placeholders // in the logical plans parsed from JSON. - var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l } - var localRelations = logicalPlan.collect { case l: LocalRelation => l } - var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i } + val logicalRDDs = new ArrayDeque[LogicalRDD]() + val localRelations = new ArrayDeque[LocalRelation]() + val inMemoryRelations = new ArrayDeque[InMemoryRelation]() + def collectData: (LogicalPlan => Unit) = { + case l: LogicalRDD => + logicalRDDs.offer(l) + case l: LocalRelation => + localRelations.offer(l) + case i: InMemoryRelation => + inMemoryRelations.offer(i) + case p => + p.expressions.foreach { + _.foreach { + case s: SubqueryExpression => + s.query.foreach(collectData) + case _ => + } + } + } + logicalPlan.foreach(collectData) + val jsonBackPlan = try { TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) @@ -253,18 +279,15 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - val normalized2 = jsonBackPlan transformDown { + def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => - val origin = logicalRDDs.head - logicalRDDs = logicalRDDs.drop(1) + val origin = logicalRDDs.pop() LogicalRDD(l.output, origin.rdd)(sqlContext) case l: LocalRelation => - val origin = localRelations.head - localRelations = localRelations.drop(1) + val origin = localRelations.pop() l.copy(data = origin.data) case l: InMemoryRelation => - val origin = inMemoryRelations.head - inMemoryRelations = inMemoryRelations.drop(1) + val origin = inMemoryRelations.pop() InMemoryRelation( l.output, l.useCompression, @@ -275,7 +298,13 @@ abstract class QueryTest extends PlanTest { origin.cachedColumnBuffers, l._statistics, origin._batchStats) + case p => + p.transformExpressions { + case s: SubqueryExpression => + s.withNewPlan(s.query.transformDown(renormalize)) + } } + val normalized2 = jsonBackPlan.transformDown(renormalize) assert(logicalRDDs.isEmpty) 
assert(localRelations.isEmpty) @@ -309,6 +338,7 @@ object QueryTest { * If there was exception during the execution or the contents of the DataFrame does not * match the expected result, an error message will be returned. Otherwise, a [[None]] will * be returned. + * * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ @@ -383,6 +413,7 @@ object QueryTest { /** * Runs the plan and makes sure the answer is within absTol of the expected result. + * * @param actualAnswer the actual result in a [[Row]]. * @param expectedAnswer the expected result in a[[Row]]. * @param absTol the absolute tolerance between actual and expected answers. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 21b19fe7df..5742983fb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -22,6 +22,38 @@ import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { import testImplicits._ + setupTestData() + + val row = identity[(java.lang.Integer, java.lang.Double)](_) + + lazy val l = Seq( + row(1, 2.0), + row(1, 2.0), + row(2, 1.0), + row(2, 1.0), + row(3, 3.0), + row(null, null), + row(null, 5.0), + row(6, null)).toDF("a", "b") + + lazy val r = Seq( + row(2, 3.0), + row(2, 3.0), + row(3, 2.0), + row(4, 1.0), + row(null, null), + row(null, 5.0), + row(6, null)).toDF("c", "d") + + lazy val t = r.filter($"c".isNotNull && $"d".isNotNull) + + protected override def beforeAll(): Unit = { + super.beforeAll() + l.registerTempTable("l") + r.registerTempTable("r") + t.registerTempTable("t") + } + test("simple uncorrelated scalar subquery") { assertResult(Array(Row(1))) { sql("select (select 1 as b) as b").collect() @@ -80,4 +112,70 @@ class SubquerySuite extends QueryTest with SharedSQLContext { " where key = (select max(key) from subqueryData) - 1)").collect() } } + + test("EXISTS predicate subquery") { + checkAnswer( + sql("select * from l where exists(select * from r where l.a = r.c)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + } + + test("NOT EXISTS predicate subquery") { + checkAnswer( + sql("select * from l where not exists(select * from r where l.a = r.c)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil) + + checkAnswer( + sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + } + + test("IN predicate subquery") { + checkAnswer( + sql("select * from l where l.a in (select c from r)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where l.a in (select c from r where l.b < r.d)"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + + checkAnswer( + sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null"), + Row(3, 3.0) :: Nil) + } + + test("NOT IN predicate subquery") { + checkAnswer( + sql("select * from l where a not in(select c from r)"), + Nil) + + checkAnswer( + sql("select * from l where a not in(select c from r where c is not null)"), + Row(1, 2.0) :: Row(1, 2.0) :: Nil) + + checkAnswer( + sql("select * 
from l where a not in(select c from t where b < d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil) + + // Empty sub-query + checkAnswer( + sql("select * from l where a not in(select c from r where c > 10 and b < d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: + Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + + } + + test("complex IN predicate subquery") { + checkAnswer( + sql("select * from l where (a, b) not in(select c, d from r)"), + Nil) + + checkAnswer( + sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil) + } } -- cgit v1.2.3