author     Herman van Hovell <hvanhovell@questtec.nl>    2016-04-19 15:16:02 -0700
committer  Davies Liu <davies.liu@gmail.com>              2016-04-19 15:16:02 -0700
commit     da8859226e09aa6ebcf6a1c5c1369dec3c216eac (patch)
tree       a72601d6d067bf81e5531e4de7d93f226186aef5
parent     3c91afec20607e0d853433a904105ee22df73c73 (diff)
download   spark-da8859226e09aa6ebcf6a1c5c1369dec3c216eac.tar.gz
           spark-da8859226e09aa6ebcf6a1c5c1369dec3c216eac.tar.bz2
           spark-da8859226e09aa6ebcf6a1c5c1369dec3c216eac.zip
[SPARK-4226] [SQL] Support IN/EXISTS Subqueries
### What changes were proposed in this pull request?

This PR adds support for in/exists predicate subqueries to Spark. Predicate sub-queries are used as a filtering condition in a query (this is the only supported use case). A predicate sub-query comes in two forms:

- `[NOT] EXISTS(subquery)`
- `[NOT] IN (subquery)`

This PR is (loosely) based on the work of davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/9055). They should be credited for the work they did.

### How was this patch tested?

Modified parsing unit tests. Added tests to `org.apache.spark.sql.SQLQuerySuite`.

cc rxin, davies & chenghao-intel

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12306 from hvanhovell/SPARK-4226.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala  30
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala  40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala  84
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  115
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala  16
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala  58
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala  6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala  8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala  53
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala  98
12 files changed, 476 insertions(+), 42 deletions(-)
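Before walking through the diff, here is a rough usage sketch of the two new predicate forms. It is not part of the patch; it assumes a `sqlContext` (a SQLContext) with the `l` and `r` DataFrames registered as temp tables, exactly as SubquerySuite does at the bottom of this change.

```scala
// Usage sketch only (not part of the patch). Assumes `sqlContext` is a SQLContext and that
// the `l` and `r` temp tables from SubquerySuite below have been registered.
// Correlated EXISTS as a filter:
sqlContext.sql("select * from l where exists(select * from r where l.a = r.c)").show()
// NOT IN against a sub-query (NULLs excluded on the sub-query side):
sqlContext.sql("select * from l where a not in(select c from r where c is not null)").show()
```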
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0e2fd43983..236476900a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -855,25 +855,35 @@ class Analyzer(
}
/**
- * This rule resolve subqueries inside expressions.
+ * This rule resolves sub-queries inside expressions.
*
- * Note: CTE are handled in CTESubstitution.
+ * Note: CTEs are handled in CTESubstitution.
*/
object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
- private def hasSubquery(e: Expression): Boolean = {
- e.find(_.isInstanceOf[SubqueryExpression]).isDefined
- }
-
- private def hasSubquery(q: LogicalPlan): Boolean = {
- q.expressions.exists(hasSubquery)
+ /**
+ * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a
+ * sub-query by using the plan the predicates should be correlated to.
+ */
+ private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = {
+ q transformUp {
+ case f @ Filter(cond, child) if child.resolved && !f.resolved =>
+ val newCond = resolveExpression(cond, p, throws = false)
+ if (!cond.fastEquals(newCond)) {
+ Filter(newCond, child)
+ } else {
+ f
+ }
+ }
}
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case q: LogicalPlan if q.childrenResolved && hasSubquery(q) =>
+ case q: LogicalPlan if q.childrenResolved =>
q transformExpressions {
case e: SubqueryExpression if !e.query.resolved =>
- e.withNewPlan(execute(e.query))
+ // First resolve as much of the sub-query as possible. After that we use the children of
+ // this plan to resolve the remaining correlated predicates.
+ e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates))
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index d6a8c3eec8..45e4d535c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.UsingJoin
+import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
/**
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
-trait CheckAnalysis {
+trait CheckAnalysis extends PredicateHelper {
/**
* Override to provide additional checks for correct analysis.
@@ -110,6 +110,39 @@ trait CheckAnalysis {
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
+ case f @ Filter(condition, child) =>
+ // Make sure that no correlated reference appears below an Aggregate, below an Outer Join,
+ // or on the right hand side of a Union or other set operation.
+ lazy val attributes = child.outputSet
+ def failOnCorrelatedReference(
+ p: LogicalPlan,
+ message: String): Unit = p.transformAllExpressions {
+ case e: NamedExpression if attributes.contains(e) =>
+ failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+ }
+ def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach {
+ case a @ Aggregate(_, _, source) =>
+ failOnCorrelatedReference(source, "an AGGREGATE")
+ case j @ Join(left, _, RightOuter, _) =>
+ failOnCorrelatedReference(left, "a RIGHT OUTER JOIN")
+ case j @ Join(_, right, jt, _) if jt != Inner =>
+ failOnCorrelatedReference(right, "a LEFT (OUTER) JOIN")
+ case Union(_ :: xs) =>
+ xs.foreach(failOnCorrelatedReference(_, "a UNION"))
+ case s: SetOperation =>
+ failOnCorrelatedReference(s.right, "an INTERSECT/EXCEPT")
+ case _ =>
+ }
+ splitConjunctivePredicates(condition).foreach {
+ case p: PredicateSubquery =>
+ checkForCorrelatedReferences(p)
+ case Not(p: PredicateSubquery) =>
+ checkForCorrelatedReferences(p)
+ case e if PredicateSubquery.hasPredicateSubquery(e) =>
+ failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e")
+ case e =>
+ }
+
case j @ Join(_, _, UsingJoin(_, cols), _) =>
val from = operator.inputSet.map(_.name).mkString(", ")
failAnalysis(
@@ -209,6 +242,9 @@ trait CheckAnalysis {
| but one table has '${firstError.output.length}' columns and another table has
| '${s.children.head.output.length}' columns""".stripMargin)
+ case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
+ failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
+
case _ => // Fallbacks to the following checks
}
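Roughly what these two new checks mean at the SQL level. The statements below are my own illustration (not from the patch) and assume a `sqlContext` with the `l` and `r` temp tables registered as in SubquerySuite further down.

```scala
// Illustration only: both statements should now be rejected during analysis.
// Predicate sub-query outside a Filter (here, in the SELECT list):
//   "Predicate sub-queries can only be used in a Filter"
sqlContext.sql("select a, a in (select c from r) as flag from l")
// Predicate sub-query nested under OR (only top-level conjuncts of a Filter are allowed):
//   "Predicate sub-queries cannot be used in nested conditions"
sqlContext.sql("select * from l where a in (select c from r) or b < 1.0")
```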
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 968bbdb1a5..cbee0e61f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
/**
* An interface for subquery that is used in expressions.
*/
-abstract class SubqueryExpression extends LeafExpression {
+abstract class SubqueryExpression extends Expression {
/**
* The logical plan of the query.
@@ -61,6 +61,8 @@ case class ScalarSubquery(
override def dataType: DataType = query.schema.fields.head.dataType
+ override def children: Seq[Expression] = Nil
+
override def checkInputDataTypes(): TypeCheckResult = {
if (query.schema.length != 1) {
TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
@@ -77,3 +79,81 @@ case class ScalarSubquery(
override def toString: String = s"subquery#${exprId.id}"
}
+
+/**
+ * A predicate subquery checks the existence of a value in a sub-query. We currently only allow
+ * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). It is
+ * rewritten into a left semi/anti join by the [[RewritePredicateSubquery]] optimizer rule.
+ */
+abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate {
+ override def nullable: Boolean = false
+ override def plan: LogicalPlan = SubqueryAlias(prettyName, query)
+}
+
+object PredicateSubquery {
+ def hasPredicateSubquery(e: Expression): Boolean = {
+ e.find(_.isInstanceOf[PredicateSubquery]).isDefined
+ }
+}
+
+/**
+ * The [[InSubQuery]] predicate checks the existence of a value in a sub-query. For example (SQL):
+ * {{{
+ * SELECT *
+ * FROM a
+ * WHERE a.id IN (SELECT id
+ * FROM b)
+ * }}}
+ */
+case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery {
+ override def children: Seq[Expression] = value :: Nil
+ override lazy val resolved: Boolean = value.resolved && query.resolved
+ override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan)
+
+ /**
+ * The unwrapped value side expressions.
+ */
+ lazy val expressions: Seq[Expression] = value match {
+ case CreateStruct(cols) => cols
+ case col => Seq(col)
+ }
+
+ /**
+ * Check if the number of columns and the data types on both sides match.
+ */
+ override def checkInputDataTypes(): TypeCheckResult = {
+ // Check the number of arguments.
+ if (expressions.length != query.output.length) {
+ return TypeCheckResult.TypeCheckFailure(
+ s"The number of fields in the value (${expressions.length}) does not match with " +
+ s"the number of columns in the subquery (${query.output.length})")
+ }
+
+ // Check the argument types.
+ expressions.zip(query.output).zipWithIndex.foreach {
+ case ((e, a), i) if e.dataType != a.dataType =>
+ return TypeCheckResult.TypeCheckFailure(
+ s"The data type of value[$i](${e.dataType}) does not match " +
+ s"subquery column '${a.name}' (${a.dataType}).")
+ case _ =>
+ }
+
+ TypeCheckResult.TypeCheckSuccess
+ }
+}
+
+/**
+ * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition.
+ * For example (SQL):
+ * {{{
+ * SELECT *
+ * FROM a
+ * WHERE EXISTS (SELECT *
+ * FROM b
+ * WHERE b.id = a.id)
+ * }}}
+ */
+case class Exists(query: LogicalPlan) extends PredicateSubquery {
+ override def children: Seq[Expression] = Nil
+ override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan)
+}
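A minimal sketch of building these new expression nodes directly with the Catalyst test DSL, mirroring the parser tests later in this diff; the imports are my assumption about where the DSL helpers live.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Exists, InSubQuery}

// a IN (SELECT b FROM c) -- the shape ExpressionParserSuite expects for "a in (select b from c)"
val in = InSubQuery('a, table("c").select('b))

// EXISTS (SELECT 1 FROM b WHERE b.x = a.x) -- the shape expected by the exists expression test
val exists = Exists(table("b").where(Symbol("b.x") === Symbol("a.x")).select(1))
```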
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0a5232b2d4..ecc2d773e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import scala.collection.immutable.HashSet
+import scala.collection.mutable
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{InSubQuery, _}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
@@ -47,6 +48,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// However, because we also use the analyzer to canonicalize queries (for view definition),
// we do not eliminate subqueries or compute current time in the analyzer.
Batch("Finish Analysis", Once,
+ RewritePredicateSubquery,
EliminateSubqueryAliases,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog),
@@ -1446,3 +1448,114 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates
+ * are supported:
+ * a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter
+ * will be pulled out as the join conditions.
+ * b. IN/NOT IN will be rewritten as semi/anti join, unresolved conditions in the Filter will
+ * be pulled out as join conditions, value = selected column will also be used as join
+ * condition.
+ */
+object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
+ /**
+ * Pull out all correlated predicates from a given sub-query. This method removes the correlated
+ * predicates from sub-query [[Filter]]s and adds the references of these predicates to
+ * all intermediate [[Project]] clauses (if they are missing) in order to be able to evaluate the
+ * predicates in the join condition.
+ *
+ * This method returns the rewritten sub-query and the extracted correlated predicates (to be
+ * combined with AND by the caller).
+ */
+ private def pullOutCorrelatedPredicates(
+ subquery: LogicalPlan,
+ query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+ val references = query.outputSet
+ val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]]
+ val transformed = subquery transformUp {
+ case f @ Filter(cond, child) =>
+ // Find all correlated predicates.
+ val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
+ e.references.intersect(references).nonEmpty
+ }
+ // Rewrite the filter without the correlated predicates if any.
+ correlated match {
+ case Nil => f
+ case xs if local.nonEmpty =>
+ val newFilter = Filter(local.reduce(And), child)
+ predicateMap += newFilter -> correlated
+ newFilter
+ case xs =>
+ predicateMap += child -> correlated
+ child
+ }
+ case p @ Project(expressions, child) =>
+ // Find all pulled out predicates defined in the Project's subtree.
+ val localPredicates = p.collect(predicateMap).flatten
+
+ // Determine which correlated predicate references are missing from this project.
+ val localPredicateReferences = localPredicates
+ .map(_.references)
+ .reduceOption(_ ++ _)
+ .getOrElse(AttributeSet.empty)
+ val missingReferences = localPredicateReferences -- p.references -- query.outputSet
+
+ // Create a new project if we need to add missing references.
+ if (missingReferences.nonEmpty) {
+ Project(expressions ++ missingReferences, child)
+ } else {
+ p
+ }
+ }
+ (transformed, predicateMap.values.flatten.toSeq)
+ }
+
+ /**
+ * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by
+ * constructing the required join condition. Both the rewritten subquery and the constructed
+ * join condition are returned.
+ */
+ private def pullOutCorrelatedPredicates(
+ in: InSubQuery,
+ query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+ val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
+ val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
+ (resolved, conditions)
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case f @ Filter(condition, child) =>
+ val (withSubquery, withoutSubquery) =
+ splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery)
+
+ // Construct the pruned filter condition.
+ val newFilter: LogicalPlan = withoutSubquery match {
+ case Nil => child
+ case conditions => Filter(conditions.reduce(And), child)
+ }
+
+ // Filter the plan by applying left semi and left anti joins.
+ withSubquery.foldLeft(newFilter) {
+ case (p, Exists(sub)) =>
+ val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
+ Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+ case (p, Not(Exists(sub))) =>
+ val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
+ Join(p, resolved, LeftAnti, conditions.reduceOption(And))
+ case (p, in: InSubQuery) =>
+ val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+ Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+ case (p, Not(in: InSubQuery)) =>
+ val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+ // This is a NULL-aware (left) anti join (NAAJ).
+ // Construct the condition. A NULL in one of the conditions is regarded as a positive
+ // result; such a row will be filtered out by the Anti-Join operator.
+ val anyNull = conditions.map(IsNull).reduceLeft(Or)
+ val condition = conditions.reduceLeft(And)
+
+ // Note that this will almost certainly be planned as a Broadcast Nested Loop join.
+ // Use EXISTS if performance matters to you.
+ Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+ }
+ }
+}
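A minimal sketch of the rewrite on a tiny, already-resolved plan, using the same DSL and LocalRelation helpers as the analysis tests; the exact shape of the output plan is approximate.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Exists}
import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val c = AttributeReference("c", IntegerType)()

// SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE l.a = r.c), expressed on LocalRelations.
val plan = LocalRelation(a).where(Exists(LocalRelation(c).where(a === c)))

// The correlated predicate a = c is pulled out of the sub-query's Filter and becomes the
// join condition of a left semi join, roughly:
//   Join(LocalRelation(a), LocalRelation(c), LeftSemi, Some(a = c))
val rewritten = RewritePredicateSubquery(plan)
```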
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index aa59f3fb2a..1c067621df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -391,9 +391,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Having
val withHaving = withProject.optional(having) {
- // Note that we added a cast to boolean. If the expression itself is already boolean,
- // the optimizer will get rid of the unnecessary cast.
- Filter(Cast(expression(having), BooleanType), withProject)
+ // Note that we add a cast to non-predicate expressions. If the expression itself is
+ // already boolean, the optimizer will get rid of the unnecessary cast.
+ val predicate = expression(having) match {
+ case p: Predicate => p
+ case e => Cast(e, BooleanType)
+ }
+ Filter(predicate, withProject)
}
// Distinct
@@ -866,10 +870,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/**
- * Create a filtering correlated sub-query. This is not supported yet.
+ * Create a filtering correlated sub-query (EXISTS).
*/
override def visitExists(ctx: ExistsContext): Expression = {
- throw new ParseException("EXISTS clauses are not supported.", ctx)
+ Exists(plan(ctx.query))
}
/**
@@ -944,7 +948,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
GreaterThanOrEqual(e, expression(ctx.lower)),
LessThanOrEqual(e, expression(ctx.upper))))
case SqlBaseParser.IN if ctx.query != null =>
- throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+ invertIfNotDefined(InSubQuery(e, plan(ctx.query)))
case SqlBaseParser.IN =>
invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
case SqlBaseParser.LIKE =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index ad101d1c40..a90636d278 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -444,4 +444,60 @@ class AnalysisErrorSuite extends AnalysisTest {
assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
}
+
+ test("PredicateSubQuery is used outside of a filter") {
+ val a = AttributeReference("a", IntegerType)()
+ val b = AttributeReference("b", IntegerType)()
+ val plan = Project(
+ Seq(a, Alias(InSubQuery(a, LocalRelation(b)), "c")()),
+ LocalRelation(a))
+ assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil)
+ }
+
+ test("PredicateSubQuery is used in a nested condition") {
+ val a = AttributeReference("a", IntegerType)()
+ val b = AttributeReference("b", IntegerType)()
+ val c = AttributeReference("c", BooleanType)()
+ val plan1 = Filter(Cast(InSubQuery(a, LocalRelation(b)), BooleanType), LocalRelation(a))
+ assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
+
+ val plan2 = Filter(Or(InSubQuery(a, LocalRelation(b)), c), LocalRelation(a, c))
+ assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil)
+ }
+
+ test("PredicateSubQuery correlated predicate is nested in an illegal plan") {
+ val a = AttributeReference("a", IntegerType)()
+ val b = AttributeReference("b", IntegerType)()
+ val c = AttributeReference("c", IntegerType)()
+
+ val plan1 = Filter(
+ Exists(
+ Join(
+ LocalRelation(b),
+ Filter(EqualTo(a, c), LocalRelation(c)),
+ LeftOuter,
+ Option(EqualTo(b, c)))),
+ LocalRelation(a))
+ assertAnalysisError(plan1, "Accessing outer query column is not allowed in" :: Nil)
+
+ val plan2 = Filter(
+ Exists(
+ Join(
+ Filter(EqualTo(a, c), LocalRelation(c)),
+ LocalRelation(b),
+ RightOuter,
+ Option(EqualTo(b, c)))),
+ LocalRelation(a))
+ assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil)
+
+ val plan3 = Filter(
+ Exists(Aggregate(Seq.empty, Seq.empty, Filter(EqualTo(a, c), LocalRelation(c)))),
+ LocalRelation(a))
+ assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
+
+ val plan4 = Filter(
+ Exists(Union(LocalRelation(b), Filter(EqualTo(a, c), LocalRelation(c)))),
+ LocalRelation(a))
+ assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
+ }
}
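Roughly the SQL-level shape of the plans these tests reject; the query below is my own illustration (not from the patch) and assumes the `l` and `r` temp tables from SubquerySuite.

```scala
// Illustration only: a correlated reference below an Aggregate inside the sub-query should
// now fail analysis with "Accessing outer query column is not allowed in ...".
sqlContext.sql("select * from l where exists(select sum(d) from r where l.a = r.c)")
```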
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
index db96bfb652..6da3eaea3d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
@@ -60,8 +60,8 @@ class ErrorParserSuite extends SparkFunSuite {
intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
"Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
"^^^")
- intercept("select * from r where a in (select * from t)", 1, 24,
- "IN with a Sub-query is currently not supported",
- "------------------------^^^")
+ intercept("select * from r except all select * from t", 1, 0,
+ "EXCEPT ALL is not supported",
+ "^^^")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 6f40ec67ec..d1dc8d621f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -113,7 +113,9 @@ class ExpressionParserSuite extends PlanTest {
}
test("exists expression") {
- intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported")
+ assertEqual(
+ "exists (select 1 from b where b.x = a.x)",
+ Exists(table("b").where(Symbol("b.x") === Symbol("a.x")).select(1)))
}
test("comparison expressions") {
@@ -139,7 +141,9 @@ class ExpressionParserSuite extends PlanTest {
}
test("in sub-query") {
- intercept("a in (select b from c)", "IN with a Sub-query is currently not supported")
+ assertEqual(
+ "a in (select b from c)",
+ InSubQuery('a, table("c").select('b)))
}
test("like expressions") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 411e2372f2..a1ca55c262 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -107,7 +107,7 @@ class PlanParserSuite extends PlanTest {
assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
assertEqual(
"select a, b from db.c having x < 1",
- table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType)))
+ table("db", "c").select('a, 'b).where('x < 1))
assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
}
@@ -405,7 +405,7 @@ class PlanParserSuite extends PlanTest {
"select g from t group by g having a > (select b from s)",
table("t")
.groupBy('g)('g)
- .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType)))
+ .where('a > ScalarSubquery(table("s").select('b))))
}
test("table reference") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index b3e8b37a2e..71b6a97852 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -42,6 +43,7 @@ case class ScalarSubquery(
override def plan: SparkPlan = Subquery(simpleString, executedPlan)
override def dataType: DataType = executedPlan.schema.fields.head.dataType
+ override def children: Seq[Expression] = Nil
override def nullable: Boolean = true
override def toString: String = s"subquery#${exprId.id}"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 2dca792c83..cbacb5e103 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import java.util.{Locale, TimeZone}
+import java.util.{ArrayDeque, Locale, TimeZone}
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
@@ -35,6 +35,8 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.MemoryPlan
import org.apache.spark.sql.types.ObjectType
+
+
abstract class QueryTest extends PlanTest {
protected def sqlContext: SQLContext
@@ -47,6 +49,7 @@ abstract class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer contains all of the keywords, or the
* none of keywords are listed in the answer
+ *
* @param df the [[DataFrame]] to be executed
* @param exists true for make sure the keywords are listed in the output, otherwise
* to make sure none of the keyword are not listed in the output
@@ -119,6 +122,7 @@ abstract class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
+ *
* @param df the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
@@ -158,6 +162,7 @@ abstract class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer is within absTol of the expected result.
+ *
* @param dataFrame the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param absTol the absolute tolerance between actual and expected answers.
@@ -198,7 +203,10 @@ abstract class QueryTest extends PlanTest {
}
private def checkJsonFormat(df: DataFrame): Unit = {
+ // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
+ // RDD and Data resolution does not break.
val logicalPlan = df.queryExecution.analyzed
+
// bypass some cases that we can't handle currently.
logicalPlan.transform {
case _: ObjectConsumer => return
@@ -236,9 +244,27 @@ abstract class QueryTest extends PlanTest {
// RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains
// these non-serializable stuff, and use these original ones to replace the null-placeholders
// in the logical plans parsed from JSON.
- var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
- var localRelations = logicalPlan.collect { case l: LocalRelation => l }
- var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }
+ val logicalRDDs = new ArrayDeque[LogicalRDD]()
+ val localRelations = new ArrayDeque[LocalRelation]()
+ val inMemoryRelations = new ArrayDeque[InMemoryRelation]()
+ def collectData: (LogicalPlan => Unit) = {
+ case l: LogicalRDD =>
+ logicalRDDs.offer(l)
+ case l: LocalRelation =>
+ localRelations.offer(l)
+ case i: InMemoryRelation =>
+ inMemoryRelations.offer(i)
+ case p =>
+ p.expressions.foreach {
+ _.foreach {
+ case s: SubqueryExpression =>
+ s.query.foreach(collectData)
+ case _ =>
+ }
+ }
+ }
+ logicalPlan.foreach(collectData)
+
val jsonBackPlan = try {
TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
@@ -253,18 +279,15 @@ abstract class QueryTest extends PlanTest {
""".stripMargin, e)
}
- val normalized2 = jsonBackPlan transformDown {
+ def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
case l: LogicalRDD =>
- val origin = logicalRDDs.head
- logicalRDDs = logicalRDDs.drop(1)
+ val origin = logicalRDDs.pop()
LogicalRDD(l.output, origin.rdd)(sqlContext)
case l: LocalRelation =>
- val origin = localRelations.head
- localRelations = localRelations.drop(1)
+ val origin = localRelations.pop()
l.copy(data = origin.data)
case l: InMemoryRelation =>
- val origin = inMemoryRelations.head
- inMemoryRelations = inMemoryRelations.drop(1)
+ val origin = inMemoryRelations.pop()
InMemoryRelation(
l.output,
l.useCompression,
@@ -275,7 +298,13 @@ abstract class QueryTest extends PlanTest {
origin.cachedColumnBuffers,
l._statistics,
origin._batchStats)
+ case p =>
+ p.transformExpressions {
+ case s: SubqueryExpression =>
+ s.withNewPlan(s.query.transformDown(renormalize))
+ }
}
+ val normalized2 = jsonBackPlan.transformDown(renormalize)
assert(logicalRDDs.isEmpty)
assert(localRelations.isEmpty)
@@ -309,6 +338,7 @@ object QueryTest {
* If there was exception during the execution or the contents of the DataFrame does not
* match the expected result, an error message will be returned. Otherwise, a [[None]] will
* be returned.
+ *
* @param df the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
@@ -383,6 +413,7 @@ object QueryTest {
/**
* Runs the plan and makes sure the answer is within absTol of the expected result.
+ *
* @param actualAnswer the actual result in a [[Row]].
* @param expectedAnswer the expected result in a[[Row]].
* @param absTol the absolute tolerance between actual and expected answers.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 21b19fe7df..5742983fb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,6 +22,38 @@ import org.apache.spark.sql.test.SharedSQLContext
class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
+ setupTestData()
+
+ val row = identity[(java.lang.Integer, java.lang.Double)](_)
+
+ lazy val l = Seq(
+ row(1, 2.0),
+ row(1, 2.0),
+ row(2, 1.0),
+ row(2, 1.0),
+ row(3, 3.0),
+ row(null, null),
+ row(null, 5.0),
+ row(6, null)).toDF("a", "b")
+
+ lazy val r = Seq(
+ row(2, 3.0),
+ row(2, 3.0),
+ row(3, 2.0),
+ row(4, 1.0),
+ row(null, null),
+ row(null, 5.0),
+ row(6, null)).toDF("c", "d")
+
+ lazy val t = r.filter($"c".isNotNull && $"d".isNotNull)
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ l.registerTempTable("l")
+ r.registerTempTable("r")
+ t.registerTempTable("t")
+ }
+
test("simple uncorrelated scalar subquery") {
assertResult(Array(Row(1))) {
sql("select (select 1 as b) as b").collect()
@@ -80,4 +112,70 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
" where key = (select max(key) from subqueryData) - 1)").collect()
}
}
+
+ test("EXISTS predicate subquery") {
+ checkAnswer(
+ sql("select * from l where exists(select * from r where l.a = r.c)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Nil)
+ }
+
+ test("NOT EXISTS predicate subquery") {
+ checkAnswer(
+ sql("select * from l where not exists(select * from r where l.a = r.c)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) ::
+ Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+ }
+
+ test("IN predicate subquery") {
+ checkAnswer(
+ sql("select * from l where l.a in (select c from r)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where l.a in (select c from r where l.b < r.d)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null"),
+ Row(3, 3.0) :: Nil)
+ }
+
+ test("NOT IN predicate subquery") {
+ checkAnswer(
+ sql("select * from l where a not in(select c from r)"),
+ Nil)
+
+ checkAnswer(
+ sql("select * from l where a not in(select c from r where c is not null)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where a not in(select c from t where b < d)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil)
+
+ // Empty sub-query
+ checkAnswer(
+ sql("select * from l where a not in(select c from r where c > 10 and b < d)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) ::
+ Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+
+ }
+
+ test("complex IN predicate subquery") {
+ checkAnswer(
+ sql("select * from l where (a, b) not in(select c, d from r)"),
+ Nil)
+
+ checkAnswer(
+ sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil)
+ }
}
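A short note on the NULL semantics exercised by the NOT IN tests above; the queries and expected rows are taken from those tests, and `sqlContext` is assumed as in the rest of the suite.

```scala
// r.c contains NULLs, and NOT IN is rewritten into a NULL-aware left anti join (see the
// Optimizer change above): a NULL comparison counts as a possible match, so every row of l
// is anti-joined away and the query returns nothing.
sqlContext.sql("select * from l where a not in(select c from r)")                      // no rows
// Excluding NULLs on the sub-query side restores the rows that truly have no match.
sqlContext.sql("select * from l where a not in(select c from r where c is not null)")  // (1, 2.0) twice
```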