From 0b0c8b95e3594db36d87ef0e59a30eefe8508ac1 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 17 Aug 2016 07:03:24 -0700
Subject: [SPARK-17106] [SQL] Simplify the SubqueryExpression interface

## What changes were proposed in this pull request?

The current subquery expression interface carries some technical debt: it exposes several different access paths for getting and setting the query contained by the expression, which is confusing to anyone reading this code. This PR unifies these access paths behind a single `plan` getter and `withNewPlan` setter on a common `PlanExpression` base class.

## How was this patch tested?

(Existing tests)

Author: Herman van Hovell

Closes #14685 from hvanhovell/SPARK-17106.
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 +-
 .../spark/sql/catalyst/expressions/subquery.scala  | 60 ++++++++++------------
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  6 +--
 .../spark/sql/catalyst/plans/QueryPlan.scala       |  4 +-
 .../org/apache/spark/sql/catalyst/SQLBuilder.scala |  2 +-
 .../org/apache/spark/sql/execution/subquery.scala  | 49 ++++++------------
 .../scala/org/apache/spark/sql/QueryTest.scala     |  4 +-
 .../execution/benchmark/TPCDSQueryBenchmark.scala  |  1 -
 8 files changed, 56 insertions(+), 74 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bd4c19181f..f540816366 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -146,7 +146,7 @@ class Analyzer(
         // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
         other transformExpressions {
           case e: SubqueryExpression =>
-            e.withNewPlan(substituteCTE(e.query, cteRelations))
+            e.withNewPlan(substituteCTE(e.plan, cteRelations))
         }
     }
   }
@@ -1091,7 +1091,7 @@ class Analyzer(
         f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
       // Step 1: Resolve the outer expressions.
       var previous: LogicalPlan = null
-      var current = e.query
+      var current = e.plan
       do {
         // Try to resolve the subquery plan using the regular analyzer.
         previous = current
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index ddbe937cba..e2e7d98e33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -17,33 +17,33 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.types._
 
 /**
- * An interface for subquery that is used in expressions.
+ * An interface for expressions that contain a [[QueryPlan]].
  */
-abstract class SubqueryExpression extends Expression {
+abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
   /** The id of the subquery expression. */
   def exprId: ExprId
 
-  /** The logical plan of the query. */
-  def query: LogicalPlan
+  /** The plan being wrapped in the query. */
+  def plan: T
 
-  /**
-   * Either a logical plan or a physical plan. The generated tree string (explain output) uses this
-   * field to explain the subquery.
-   */
-  def plan: QueryPlan[_]
-
-  /** Updates the query with new logical plan. */
-  def withNewPlan(plan: LogicalPlan): SubqueryExpression
+  /** Updates the expression with a new plan. */
+  def withNewPlan(plan: T): PlanExpression[T]
 
   protected def conditionString: String = children.mkString("[", " && ", "]")
 }
 
+/**
+ * A base interface for expressions that contain a [[LogicalPlan]].
+ */
+abstract class SubqueryExpression extends PlanExpression[LogicalPlan] {
+  override def withNewPlan(plan: LogicalPlan): SubqueryExpression
+}
+
 object SubqueryExpression {
   def hasCorrelatedSubquery(e: Expression): Boolean = {
     e.find {
@@ -60,20 +60,19 @@
  * Note: `exprId` is used to have a unique name in explain string output.
  */
 case class ScalarSubquery(
-    query: LogicalPlan,
+    plan: LogicalPlan,
     children: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId)
   extends SubqueryExpression with Unevaluable {
-  override lazy val resolved: Boolean = childrenResolved && query.resolved
+  override lazy val resolved: Boolean = childrenResolved && plan.resolved
   override lazy val references: AttributeSet = {
-    if (query.resolved) super.references -- query.outputSet
+    if (plan.resolved) super.references -- plan.outputSet
     else super.references
   }
-  override def dataType: DataType = query.schema.fields.head.dataType
+  override def dataType: DataType = plan.schema.fields.head.dataType
   override def foldable: Boolean = false
   override def nullable: Boolean = true
-  override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
-  override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
+  override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
   override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
 }
 
@@ -92,19 +91,18 @@ object ScalarSubquery {
  * be rewritten into a left semi/anti join during analysis.
  */
 case class PredicateSubquery(
-    query: LogicalPlan,
+    plan: LogicalPlan,
     children: Seq[Expression] = Seq.empty,
     nullAware: Boolean = false,
     exprId: ExprId = NamedExpression.newExprId)
   extends SubqueryExpression with Predicate with Unevaluable {
-  override lazy val resolved = childrenResolved && query.resolved
-  override lazy val references: AttributeSet = super.references -- query.outputSet
+  override lazy val resolved = childrenResolved && plan.resolved
+  override lazy val references: AttributeSet = super.references -- plan.outputSet
   override def nullable: Boolean = nullAware
-  override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
-  override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
+  override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan)
   override def semanticEquals(o: Expression): Boolean = o match {
     case p: PredicateSubquery =>
-      query.sameResult(p.query) && nullAware == p.nullAware &&
+      plan.sameResult(p.plan) && nullAware == p.nullAware &&
       children.length == p.children.length &&
       children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
     case _ => false
@@ -146,14 +144,13 @@ object PredicateSubquery {
  *                FROM   b)
 * }}}
 */
-case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
+case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
   extends SubqueryExpression with Unevaluable {
   override lazy val resolved = false
   override def children: Seq[Expression] = Seq.empty
   override def dataType: DataType = ArrayType(NullType)
   override def nullable: Boolean = false
-  override def withNewPlan(plan: LogicalPlan): ListQuery = copy(query = plan)
-  override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
+  override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def toString: String = s"list#${exprId.id}"
 }
 
@@ -168,12 +165,11 @@ case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExp
  *                WHERE  b.id = a.id)
 * }}}
 */
-case class Exists(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
+case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
   extends SubqueryExpression with Predicate with Unevaluable {
   override lazy val resolved = false
   override def children: Seq[Expression] = Seq.empty
   override def nullable: Boolean = false
-  override def withNewPlan(plan: LogicalPlan): Exists = copy(query = plan)
-  override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
+  override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
   override def toString: String = s"exists#${exprId.id}"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f97a78b411..aa15f4a823 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -127,7 +127,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
   object OptimizeSubqueries extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case s: SubqueryExpression =>
-        s.withNewPlan(Optimizer.this.execute(s.query))
+        s.withNewPlan(Optimizer.this.execute(s.plan))
     }
   }
 }
@@ -1814,7 +1814,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
     val newExpression = expression transform {
       case s: ScalarSubquery if s.children.nonEmpty =>
         subqueries += s
-        s.query.output.head
+        s.plan.output.head
     }
     newExpression.asInstanceOf[E]
   }
@@ -2029,7 +2029,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       // grouping expressions. As a result we need to replace all the scalar subqueries in the
       // grouping expressions by their result.
       val newGrouping = grouping.map { e =>
-        subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
+        subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
       }
       Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
     } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index becf6945a2..8ee31f42ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -263,7 +263,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * All the subqueries of current plan.
    */
   def subqueries: Seq[PlanType] = {
-    expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
+    expressions.flatMap(_.collect {
+      case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType]
+    })
   }
 
   override protected def innerChildren: Seq[QueryPlan[_]] = subqueries
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index ff8e0f2642..0f51aa58d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -80,7 +80,7 @@ class SQLBuilder private (
     try {
       val replaced = finalPlan.transformAllExpressions {
         case s: SubqueryExpression =>
-          val query = new SQLBuilder(s.query, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
+          val query = new SQLBuilder(s.plan, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
           val sql = s match {
             case _: ListQuery => query
             case _: Exists => s"EXISTS($query)"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index c730bee6ae..730ca27f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -22,9 +22,8 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
@@ -32,18 +31,7 @@ import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 /**
  * The base class for subquery that is used in SparkPlan.
  */
-trait ExecSubqueryExpression extends SubqueryExpression {
-
-  val executedPlan: SubqueryExec
-  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression
-
-  // does not have logical plan
-  override def query: LogicalPlan = throw new UnsupportedOperationException
-  override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
-    throw new UnsupportedOperationException
-
-  override def plan: SparkPlan = executedPlan
-
+abstract class ExecSubqueryExpression extends PlanExpression[SubqueryExec] {
   /**
    * Fill the expression with collected result from executed plan.
   */
@@ -56,30 +44,29 @@ trait ExecSubqueryExpression extends SubqueryExpression {
  * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
 */
 case class ScalarSubquery(
-    executedPlan: SubqueryExec,
+    plan: SubqueryExec,
     exprId: ExprId)
   extends ExecSubqueryExpression {
 
-  override def dataType: DataType = executedPlan.schema.fields.head.dataType
+  override def dataType: DataType = plan.schema.fields.head.dataType
   override def children: Seq[Expression] = Nil
   override def nullable: Boolean = true
-  override def toString: String = executedPlan.simpleString
-
-  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+  override def toString: String = plan.simpleString
+  override def withNewPlan(query: SubqueryExec): ScalarSubquery = copy(plan = query)
 
   override def semanticEquals(other: Expression): Boolean = other match {
-    case s: ScalarSubquery => executedPlan.sameResult(executedPlan)
+    case s: ScalarSubquery => plan.sameResult(s.plan)
     case _ => false
   }
 
   // the first column in first row from `query`.
-  @volatile private var result: Any = null
+  @volatile private var result: Any = _
   @volatile private var updated: Boolean = false
 
   def updateResult(): Unit = {
     val rows = plan.executeCollect()
     if (rows.length > 1) {
-      sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
+      sys.error(s"more than one row returned by a subquery used as an expression:\n$plan")
     }
     if (rows.length == 1) {
       assert(rows(0).numFields == 1,
@@ -108,7 +95,7 @@ case class ScalarSubquery(
 */
 case class InSubquery(
     child: Expression,
-    executedPlan: SubqueryExec,
+    plan: SubqueryExec,
     exprId: ExprId,
     private var result: Array[Any] = null,
     private var updated: Boolean = false) extends ExecSubqueryExpression {
@@ -116,13 +103,11 @@ case class InSubquery(
   override def dataType: DataType = BooleanType
   override def children: Seq[Expression] = child :: Nil
   override def nullable: Boolean = child.nullable
-  override def toString: String = s"$child IN ${executedPlan.name}"
-
-  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+  override def toString: String = s"$child IN ${plan.name}"
+  override def withNewPlan(plan: SubqueryExec): InSubquery = copy(plan = plan)
 
   override def semanticEquals(other: Expression): Boolean = other match {
-    case in: InSubquery => child.semanticEquals(in.child) &&
-      executedPlan.sameResult(in.executedPlan)
+    case in: InSubquery => child.semanticEquals(in.child) && plan.sameResult(in.plan)
     case _ => false
   }
 
@@ -159,8 +144,8 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
         ScalarSubquery(
           SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
           subquery.exprId)
-      case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
-        val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
+      case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) =>
+        val executedPlan = new QueryExecution(sparkSession, query).executedPlan
         InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
     }
   }
@@ -184,9 +169,9 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
         val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
         val sameResult = sameSchema.find(_.sameResult(sub.plan))
         if (sameResult.isDefined) {
-          sub.withExecutedPlan(sameResult.get)
+          sub.withNewPlan(sameResult.get)
         } else {
-          sameSchema += sub.executedPlan
+          sameSchema += sub.plan
           sub
         }
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 304881d4a4..cff9d22d08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -292,7 +292,7 @@ abstract class QueryTest extends PlanTest {
       p.expressions.foreach {
         _.foreach {
           case s: SubqueryExpression =>
-            s.query.foreach(collectData)
+            s.plan.foreach(collectData)
           case _ =>
         }
       }
@@ -334,7 +334,7 @@ abstract class QueryTest extends PlanTest {
       case p =>
        p.transformExpressions {
          case s: SubqueryExpression =>
-            s.withNewPlan(s.query.transformDown(renormalize))
+            s.withNewPlan(s.plan.transformDown(renormalize))
        }
     }
     val normalized2 = jsonBackPlan.transformDown(renormalize)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
index 957a1d6426..3988d9750b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Benchmark
 
 /**
--
cgit v1.2.3
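For readers skimming the diff, here is a minimal, self-contained Scala sketch of the shape this refactoring leaves behind: one `plan` getter and one `withNewPlan` setter on a generic `PlanExpression`, specialized for logical and physical plans. The `PlanExpression`/`ScalarSubquery` names mirror the patch; the toy `QueryPlan`, `LogicalPlan`, `SparkPlan`, `ExecScalarSubquery`, and `Demo` definitions are simplified stand-ins for illustration, not Spark's real classes.

```scala
// Toy model of the unified access path introduced by this patch.
// All plan types below are simplified stand-ins (assumptions).
trait QueryPlan
case class LogicalPlan(description: String) extends QueryPlan
case class SparkPlan(description: String) extends QueryPlan

// The unified interface: a single getter/setter pair, parameterized
// over the kind of plan the expression wraps.
trait PlanExpression[T <: QueryPlan] {
  def plan: T
  def withNewPlan(plan: T): PlanExpression[T]
}

// Logical-side subquery expression (mirrors catalyst's ScalarSubquery).
case class ScalarSubquery(plan: LogicalPlan) extends PlanExpression[LogicalPlan] {
  override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
}

// Physical-side subquery expression (mirrors execution.ScalarSubquery);
// before this patch it needed a separate executedPlan/withExecutedPlan pair.
case class ExecScalarSubquery(plan: SparkPlan) extends PlanExpression[SparkPlan] {
  override def withNewPlan(plan: SparkPlan): ExecScalarSubquery = copy(plan = plan)
}

object Demo extends App {
  // A rewrite rule can now update any subquery through the same two methods.
  val before = ScalarSubquery(LogicalPlan("unoptimized subquery plan"))
  val after = before.withNewPlan(LogicalPlan("optimized subquery plan"))
  println(after.plan.description) // prints: optimized subquery plan
}
```

With this shape, rules such as `OptimizeSubqueries` and `ReuseSubquery` rewrite the wrapped plan through the same `plan`/`withNewPlan` pair regardless of whether the expression holds a logical or a physical subquery, which is the unification the PR description refers to.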