author    Herman van Hovell <hvanhovell@questtec.nl>  2016-05-02 16:32:31 -0700
committer Davies Liu <davies.liu@gmail.com>  2016-05-02 16:32:31 -0700
commit    f362363d148e2df4549fed5c3fd1cf20d0848fd0 (patch)
tree      ae72eececa383e88ed8790acd98896c4fe52314d /sql/catalyst/src
parent    917d05f43bddc1728735979fe7e62fe631b35e6f (diff)
[SPARK-14785] [SQL] Support correlated scalar subqueries
## What changes were proposed in this pull request?

This PR adds support for correlated scalar subqueries. An example of such a query is:

```SQL
select *
from   tbl1 a
where  a.value > (select max(value) from tbl2 b where b.key = a.key)
```

The implementation adds the `RewriteCorrelatedScalarSubquery` rule to the Optimizer. This rule plans these subqueries using `LEFT OUTER` joins. It currently supports rewrites for `Project`, `Aggregate` & `Filter` logical plans.

I could not find well-defined semantics for the use of scalar subqueries in an `Aggregate`. The current implementation evaluates the scalar subquery *before* aggregation, which means that you either have to make the scalar subquery part of the grouping expressions, or you have to aggregate it further. I am open to suggestions on this.

The implementation currently enforces that a correlated scalar subquery returns a single value by requiring that it is aggregated and that the resulting column is wrapped in an `AggregateExpression`.

## How was this patch tested?

Added tests to `SubquerySuite`.

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12822 from hvanhovell/SPARK-14785.
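For readers unfamiliar with the rewrite, here is a sketch of what the `Filter` case conceptually produces for the example above. This is an editorial illustration, not output from the commit; the inline view and the `max_value` alias are hypothetical names:

```SQL
-- Conceptual LEFT OUTER join form of the example query (sketch).
select a.*
from   tbl1 a
       left outer join (
         -- the correlated predicate b.key = a.key is pulled up into the
         -- join condition; grouping on key guarantees at most one match
         select key, max(value) as max_value
         from   tbl2
         group  by key
       ) b
       on b.key = a.key
where  a.value > b.max_value  -- a null max_value (no matching key) filters the row out
```

When no row of `tbl2` matches, the join produces a null `max_value`, so the comparison evaluates to unknown and the outer row is filtered out, matching the behavior of the original correlated predicate.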
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 11
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 42
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala | 39
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 82
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala | 2
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 11
6 files changed, 148 insertions, 39 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2f8ab3f435..59af5b7095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1081,10 +1081,10 @@ class Analyzer(
// Step 2: Pull out the predicates if the plan is resolved.
if (current.resolved) {
// Make sure the resolved query has the required number of output columns. This is only
- // needed for IN expressions.
+ // needed for Scalar and IN subqueries.
if (requiredColumns > 0 && requiredColumns != current.output.size) {
- failAnalysis(s"The number of fields in the value ($requiredColumns) does not " +
- s"match with the number of columns in the subquery (${current.output.size})")
+ failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
+ s"does not match the required number of columns ($requiredColumns)")
}
// Pull out the predicates and construct a new plan.
f.tupled(rewriteSubQuery(current, plans))
@@ -1099,8 +1099,11 @@ class Analyzer(
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan transformExpressions {
+ case s @ ScalarSubquery(sub, conditions, exprId)
+ if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
+ failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}")
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
- resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
+ resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, exprId) =>
resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6e3a14dfb9..800bf01abd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin}
+import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -60,9 +60,6 @@ trait CheckAnalysis extends PredicateHelper {
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")
- case ScalarSubquery(_, conditions, _) if conditions.nonEmpty =>
- failAnalysis("Correlated scalar subqueries are not supported.")
-
case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
@@ -104,6 +101,36 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis(s"Window specification $s is not valid because $m")
case None => w
}
+
+ case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
+ // Make sure we are using equi-joins.
+ conditions.foreach {
+ case _: EqualTo | _: EqualNullSafe => // ok
+ case e => failAnalysis(
+ s"The correlated scalar subquery can only contain equality predicates: $e")
+ }
+
+ // Make sure correlated scalar subqueries return at most one row for every outer row by
+ // enforcing that they are aggregates containing at least one aggregate expression.
+ // The analyzer has already checked that the subquery contains only one output column, and
+ // has added all the grouping expressions to the aggregate.
+ def checkAggregate(a: Aggregate): Unit = {
+ val aggregates = a.expressions.flatMap(_.collect {
+ case a: AggregateExpression => a
+ })
+ if (aggregates.isEmpty) {
+ failAnalysis("The output of a correlated scalar subquery must be aggregated")
+ }
+ }
+
+ query match {
+ case a: Aggregate => checkAggregate(a)
+ case Filter(_, a: Aggregate) => checkAggregate(a)
+ case Project(_, a: Aggregate) => checkAggregate(a)
+ case Project(_, Filter(_, a: Aggregate)) => checkAggregate(a)
+ case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
+ }
+ s
}
operator match {
@@ -220,6 +247,13 @@ trait CheckAnalysis extends PredicateHelper {
| but one table has '${firstError.output.length}' columns and another table has
| '${s.children.head.output.length}' columns""".stripMargin)
+ case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
+ p match {
+ case _: Filter | _: Aggregate | _: Project => // Ok
+ case other => failAnalysis(
+ s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
+ }
+
case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index eed062f8bc..5001f9a41e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -44,6 +44,15 @@ abstract class SubqueryExpression extends Expression {
protected def conditionString: String = children.mkString("[", " && ", "]")
}
+object SubqueryExpression {
+ def hasCorrelatedSubquery(e: Expression): Boolean = {
+ e.find {
+ case e: SubqueryExpression if e.children.nonEmpty => true
+ case _ => false
+ }.isDefined
+ }
+}
+
/**
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
@@ -55,28 +64,26 @@ case class ScalarSubquery(
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {
-
- override def plan: LogicalPlan = SubqueryAlias(toString, query)
-
override lazy val resolved: Boolean = childrenResolved && query.resolved
-
- override def dataType: DataType = query.schema.fields.head.dataType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (query.schema.length != 1) {
- TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
- query.schema.length.toString)
- } else {
- TypeCheckResult.TypeCheckSuccess
- }
+ override lazy val references: AttributeSet = {
+ if (query.resolved) super.references -- query.outputSet
+ else super.references
}
-
+ override def dataType: DataType = query.schema.fields.head.dataType
override def foldable: Boolean = false
override def nullable: Boolean = true
-
+ override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
+ override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
+}
- override def toString: String = s"subquery#${exprId.id} $conditionString"
+object ScalarSubquery {
+ def hasCorrelatedScalarSubquery(e: Expression): Boolean = {
+ e.find {
+ case e: ScalarSubquery if e.children.nonEmpty => true
+ case _ => false
+ }.isDefined
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e1c969f50f..a3ab89dc71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import scala.collection.immutable.HashSet
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
@@ -100,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
+ RewriteCorrelatedScalarSubquery,
EliminateSerialization) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
@@ -1081,7 +1083,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) = conditions.partition(
- e => !PredicateSubquery.hasPredicateSubquery(e))
+ e => !SubqueryExpression.hasCorrelatedSubquery(e))
val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And))
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
@@ -1101,7 +1103,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
- e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e))
+ e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))
// should not have reference to same logical plan
@@ -1134,7 +1136,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper {
* Returns whether the expression returns null or false when all inputs are nulls.
*/
private def canFilterOutNull(e: Expression): Boolean = {
- if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false
+ if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false
val attributes = e.references.toSeq
val emptyRow = new GenericInternalRow(attributes.length)
val v = BindReferences.bindReference(e, attributes).eval(emptyRow)
@@ -1203,7 +1205,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)
-
joinType match {
case Inner =>
// push down the single side `where` condition into respective sides
@@ -1212,7 +1213,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val (newJoinConditions, others) =
- commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e))
+ commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
val join = Join(newLeft, newRight, Inner, newJoinCond)
@@ -1573,3 +1574,74 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
}
+
+/**
+ * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins.
+ */
+object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
+ /**
+ * Extract all correlated scalar subqueries from an expression. The subqueries are collected in
+ * the given buffer. The rewritten expression is returned.
+ */
+ private def extractCorrelatedScalarSubqueries[E <: Expression](
+ expression: E,
+ subqueries: ArrayBuffer[ScalarSubquery]): E = {
+ val newExpression = expression transform {
+ case s: ScalarSubquery if s.children.nonEmpty =>
+ subqueries += s
+ s.query.output.head
+ }
+ newExpression.asInstanceOf[E]
+ }
+
+ /**
+ * Construct a new child plan by left joining the given subqueries to a base plan.
+ */
+ private def constructLeftJoins(
+ child: LogicalPlan,
+ subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
+ subqueries.foldLeft(child) {
+ case (currentChild, ScalarSubquery(query, conditions, _)) =>
+ Project(
+ currentChild.output :+ query.output.head,
+ Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+ }
+ }
+
+ /**
+ * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
+ * subqueries.
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case a @ Aggregate(grouping, expressions, child) =>
+ val subqueries = ArrayBuffer.empty[ScalarSubquery]
+ val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+ if (subqueries.nonEmpty) {
+ // We currently only allow correlated subqueries in an aggregate if they are part of the
+ // grouping expressions. As a result we need to replace all the scalar subqueries in the
+ // grouping expressions by their result.
+ val newGrouping = grouping.map { e =>
+ subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
+ }
+ Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+ } else {
+ a
+ }
+ case p @ Project(expressions, child) =>
+ val subqueries = ArrayBuffer.empty[ScalarSubquery]
+ val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+ if (subqueries.nonEmpty) {
+ Project(newExpressions, constructLeftJoins(child, subqueries))
+ } else {
+ p
+ }
+ case f @ Filter(condition, child) =>
+ val subqueries = ArrayBuffer.empty[ScalarSubquery]
+ val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
+ if (subqueries.nonEmpty) {
+ Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+ } else {
+ f
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 830a7ac77d..7b4615db06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -109,7 +109,7 @@ case class Filter(condition: Expression, child: LogicalPlan)
override protected def validConstraints: Set[Expression] = {
val predicates = splitConjunctivePredicates(condition)
- .filterNot(PredicateSubquery.hasPredicateSubquery)
+ .filterNot(SubqueryExpression.hasCorrelatedSubquery)
child.constraints.union(predicates.toSet)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 10bff3d6d8..2e88f61d49 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -111,7 +111,8 @@ class AnalysisErrorSuite extends AnalysisTest {
"scalar subquery with 2 columns",
testRelation.select(
(ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)),
- "Scalar subquery must return only one column, but got 2" :: Nil)
+ "The number of columns in the subquery (2)" ::
+ "does not match the required number of columns (1)":: Nil)
errorTest(
"scalar subquery with no column",
@@ -499,12 +500,4 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(a))
assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
}
-
- test("Correlated Scalar Subquery") {
- val a = AttributeReference("a", IntegerType)()
- val b = AttributeReference("b", IntegerType)()
- val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
- val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a))
- assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil)
- }
}
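
Editorial appendix: a hedged sketch of the `Aggregate` semantics described in the commit message. Since the scalar subquery is evaluated *before* aggregation, it must either appear in the grouping expressions or be aggregated further. The diff shows only the optimizer handling, so treat the SQL below as an illustration of the intended semantics rather than a tested example; `tbl1`/`tbl2` are the tables from the commit message:

```SQL
-- 1. The subquery result is part of the grouping expressions:
select   (select max(value) from tbl2 b where b.key = a.key) as mx,
         count(*)
from     tbl1 a
group by (select max(value) from tbl2 b where b.key = a.key);

-- 2. The per-row subquery result is aggregated further:
select sum((select max(value) from tbl2 b where b.key = a.key))
from   tbl1 a;
```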