diff options
author | Herman van Hovell <hvanhovell@questtec.nl> | 2016-05-02 16:32:31 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-05-02 16:32:31 -0700 |
commit | f362363d148e2df4549fed5c3fd1cf20d0848fd0 (patch) | |
tree | ae72eececa383e88ed8790acd98896c4fe52314d /sql/catalyst/src | |
parent | 917d05f43bddc1728735979fe7e62fe631b35e6f (diff) | |
download | spark-f362363d148e2df4549fed5c3fd1cf20d0848fd0.tar.gz spark-f362363d148e2df4549fed5c3fd1cf20d0848fd0.tar.bz2 spark-f362363d148e2df4549fed5c3fd1cf20d0848fd0.zip |
[SPARK-14785] [SQL] Support correlated scalar subqueries
## What changes were proposed in this pull request?
In this PR we add support for correlated scalar subqueries. An example of such a query is:
```SQL
select * from tbl1 a where a.value > (select max(value) from tbl2 b where b.key = a.key)
```
The implementation adds the `RewriteCorrelatedScalarSubquery` rule to the Optimizer. This rule plans these subqueries using `LEFT OUTER` joins. It currently supports rewrites for `Project`, `Aggregate` & `Filter` logical plans.
I could not find well-defined semantics for the use of scalar subqueries in an `Aggregate`. The current implementation evaluates the scalar subquery *before* aggregation. This means that you either have to make the scalar subquery part of the grouping expressions, or you have to aggregate it further on. I am open to suggestions on this.
The implementation currently forces the uniqueness of a scalar subquery by enforcing that it is aggregated and that the resulting column is wrapped in an `AggregateExpression`.
## How was this patch tested?
Added tests to `SubquerySuite`.
Author: Herman van Hovell <hvanhovell@questtec.nl>
Closes #12822 from hvanhovell/SPARK-14785.
Diffstat (limited to 'sql/catalyst/src')
6 files changed, 148 insertions, 39 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f8ab3f435..59af5b7095 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1081,10 +1081,10 @@ class Analyzer( // Step 2: Pull out the predicates if the plan is resolved. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only - // needed for IN expressions. + // needed for Scalar and IN subqueries. if (requiredColumns > 0 && requiredColumns != current.output.size) { - failAnalysis(s"The number of fields in the value ($requiredColumns) does not " + - s"match with the number of columns in the subquery (${current.output.size})") + failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + + s"does not match the required number of columns ($requiredColumns)") } // Pullout predicates and construct a new plan. 
f.tupled(rewriteSubQuery(current, plans)) @@ -1099,8 +1099,11 @@ class Analyzer( */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { + case s @ ScalarSubquery(sub, conditions, exprId) + if sub.resolved && conditions.isEmpty && sub.output.size != 1 => + failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}") case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, exprId) => resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6e3a14dfb9..800bf01abd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin} +import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -60,9 +60,6 @@ trait CheckAnalysis extends PredicateHelper { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") - case ScalarSubquery(_, conditions, _) if conditions.nonEmpty => - failAnalysis("Correlated scalar subqueries are not supported.") - case e: Expression if 
e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => @@ -104,6 +101,36 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } + + case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => + // Make sure we are using equi-joins. + conditions.foreach { + case _: EqualTo | _: EqualNullSafe => // ok + case e => failAnalysis( + s"The correlated scalar subquery can only contain equality predicates: $e") + } + + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates which contain exactly one aggregate expressions. + // The analyzer has already checked that subquery contained only one output column, and + // added all the grouping expressions to the aggregate. + def checkAggregate(a: Aggregate): Unit = { + val aggregates = a.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } + } + + query match { + case a: Aggregate => checkAggregate(a) + case Filter(_, a: Aggregate) => checkAggregate(a) + case Project(_, a: Aggregate) => checkAggregate(a) + case Project(_, Filter(_, a: Aggregate)) => checkAggregate(a) + case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") + } + s } operator match { @@ -220,6 +247,13 @@ trait CheckAnalysis extends PredicateHelper { | but one table has '${firstError.output.length}' columns and another table has | '${s.children.head.output.length}' columns""".stripMargin) + case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => + p match { + case _: Filter | _: Aggregate | _: Project => // Ok + case other => failAnalysis( + s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") + } + case p if 
p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index eed062f8bc..5001f9a41e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -44,6 +44,15 @@ abstract class SubqueryExpression extends Expression { protected def conditionString: String = children.mkString("[", " && ", "]") } +object SubqueryExpression { + def hasCorrelatedSubquery(e: Expression): Boolean = { + e.find { + case e: SubqueryExpression if e.children.nonEmpty => true + case _ => false + }.isDefined + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. 
@@ -55,28 +64,26 @@ case class ScalarSubquery( children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) extends SubqueryExpression with Unevaluable { - - override def plan: LogicalPlan = SubqueryAlias(toString, query) - override lazy val resolved: Boolean = childrenResolved && query.resolved - - override def dataType: DataType = query.schema.fields.head.dataType - - override def checkInputDataTypes(): TypeCheckResult = { - if (query.schema.length != 1) { - TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + - query.schema.length.toString) - } else { - TypeCheckResult.TypeCheckSuccess - } + override lazy val references: AttributeSet = { + if (query.resolved) super.references -- query.outputSet + else super.references } - + override def dataType: DataType = query.schema.fields.head.dataType override def foldable: Boolean = false override def nullable: Boolean = true - + override def plan: LogicalPlan = SubqueryAlias(toString, query) override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan) + override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" +} - override def toString: String = s"subquery#${exprId.id} $conditionString" +object ScalarSubquery { + def hasCorrelatedScalarSubquery(e: Expression): Boolean = { + e.find { + case e: ScalarSubquery if e.children.nonEmpty => true + case _ => false + }.isDefined + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e1c969f50f..a3ab89dc71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import 
scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} @@ -100,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, + RewriteCorrelatedScalarSubquery, EliminateSerialization) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: @@ -1081,7 +1083,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { assert(input.size >= 2) if (input.size == 2) { val (joinConditions, others) = conditions.partition( - e => !PredicateSubquery.hasPredicateSubquery(e)) + e => !SubqueryExpression.hasCorrelatedSubquery(e)) val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And)) if (others.nonEmpty) { Filter(others.reduceLeft(And), join) @@ -1101,7 +1103,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val joinedRefs = left.outputSet ++ right.outputSet val (joinConditions, others) = conditions.partition( - e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e)) + e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) // should not have reference to same logical plan @@ -1134,7 +1136,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { * Returns whether the expression returns null or false when all inputs are nulls. 
*/ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false + if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val v = BindReferences.bindReference(e, attributes).eval(emptyRow) @@ -1203,7 +1205,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) - joinType match { case Inner => // push down the single side `where` condition into respective sides @@ -1212,7 +1213,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val (newJoinConditions, others) = - commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e)) + commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e)) val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) val join = Join(newLeft, newRight, Inner, newJoinCond) @@ -1573,3 +1574,74 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } } + +/** + * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. + */ +object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { + /** + * Extract all correlated scalar subqueries from an expression. The subqueries are collected using + * the given collector. The expression is rewritten and returned. 
+ */ + private def extractCorrelatedScalarSubqueries[E <: Expression]( + expression: E, + subqueries: ArrayBuffer[ScalarSubquery]): E = { + val newExpression = expression transform { + case s: ScalarSubquery if s.children.nonEmpty => + subqueries += s + s.query.output.head + } + newExpression.asInstanceOf[E] + } + + /** + * Construct a new child plan by left joining the given subqueries to a base plan. + */ + private def constructLeftJoins( + child: LogicalPlan, + subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { + subqueries.foldLeft(child) { + case (currentChild, ScalarSubquery(query, conditions, _)) => + Project( + currentChild.output :+ query.output.head, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + } + } + + /** + * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar + * subqueries. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, expressions, child) => + val subqueries = ArrayBuffer.empty[ScalarSubquery] + val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) + if (subqueries.nonEmpty) { + // We currently only allow correlated subqueries in an aggregate if they are part of the + // grouping expressions. As a result we need to replace all the scalar subqueries in the + // grouping expressions by their result. 
+ val newGrouping = grouping.map { e => + subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e) + } + Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries)) + } else { + a + } + case p @ Project(expressions, child) => + val subqueries = ArrayBuffer.empty[ScalarSubquery] + val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) + if (subqueries.nonEmpty) { + Project(newExpressions, constructLeftJoins(child, subqueries)) + } else { + p + } + case f @ Filter(condition, child) => + val subqueries = ArrayBuffer.empty[ScalarSubquery] + val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries) + if (subqueries.nonEmpty) { + Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries))) + } else { + f + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 830a7ac77d..7b4615db06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -109,7 +109,7 @@ case class Filter(condition: Expression, child: LogicalPlan) override protected def validConstraints: Set[Expression] = { val predicates = splitConjunctivePredicates(condition) - .filterNot(PredicateSubquery.hasPredicateSubquery) + .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 10bff3d6d8..2e88f61d49 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -111,7 +111,8 @@ class AnalysisErrorSuite extends AnalysisTest { "scalar subquery with 2 columns", testRelation.select( (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), - "Scalar subquery must return only one column, but got 2" :: Nil) + "The number of columns in the subquery (2)" :: + "does not match the required number of columns (1)":: Nil) errorTest( "scalar subquery with no column", @@ -499,12 +500,4 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) } - - test("Correlated Scalar Subquery") { - val a = AttributeReference("a", IntegerType)() - val b = AttributeReference("b", IntegerType)() - val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) - val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a)) - assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil) - } } |