From 6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 24 Mar 2016 00:51:31 +0800 Subject: [SPARK-13549][SQL] Refactor the Optimizer Rule CollapseProject #### What changes were proposed in this pull request? The PR https://github.com/apache/spark/pull/10541 changed the rule `CollapseProject` by enabling collapsing `Project` into `Aggregate`. It leaves a to-do item to remove the duplicate code. This PR is to finish this to-do item. Also added a test case for covering this change. #### How was this patch tested? Added a new test case. liancheng Could you check if the code refactoring is fine? Thanks! Author: gatorsmile Closes #11427 from gatorsmile/collapseProjectRefactor. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 101 +++++++++------------ .../catalyst/optimizer/CollapseProjectSuite.scala | 26 +++++- 2 files changed, 70 insertions(+), 57 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0840d46e4e..4cfdcf95cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -417,68 +417,57 @@ object ColumnPruning extends Rule[LogicalPlan] { object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p @ Project(projectList1, Project(projectList2, child)) => - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliasMap = AttributeMap(projectList2.collect { - case a: Alias => (a.toAttribute, a) - }) - - // We only collapse these two Projects if their overlapped expressions are all - // deterministic. - val hasNondeterministic = projectList1.exists(_.collect { - case a: Attribute if aliasMap.contains(a) => aliasMap(a).child - }.exists(!_.deterministic)) - - if (hasNondeterministic) { + case p1 @ Project(_, p2: Project) => + if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { + p1 + } else { + p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) + } + case p @ Project(_, agg: Aggregate) => + if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { p } else { - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. - val substitutedProjection = projectList1.map(_.transform { - case a: Attribute => aliasMap.getOrElse(a, a) - }).asInstanceOf[Seq[NamedExpression]] - // collapse 2 projects may introduce unnecessary Aliases, trim them here. - val cleanedProjection = substitutedProjection.map(p => - CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] - ) - Project(cleanedProjection, child) + agg.copy(aggregateExpressions = buildCleanedProjectList( + p.projectList, agg.aggregateExpressions)) } + } - // TODO Eliminate duplicate code - // This clause is identical to the one above except that the inner operator is an `Aggregate` - // rather than a `Project`. - case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) => - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliasMap = AttributeMap(projectList2.collect { - case a: Alias => (a.toAttribute, a) - }) + private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = { + AttributeMap(projectList.collect { + case a: Alias => a.toAttribute -> a + }) + } - // We only collapse these two Projects if their overlapped expressions are all - // deterministic. - val hasNondeterministic = projectList1.exists(_.collect { - case a: Attribute if aliasMap.contains(a) => aliasMap(a).child - }.exists(!_.deterministic)) + private def haveCommonNonDeterministicOutput( + upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { + // Create a map of Aliases to their values from the lower projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliases = collectAliases(lower) + + // Collapse upper and lower Projects if and only if their overlapped expressions are all + // deterministic. + upper.exists(_.collect { + case a: Attribute if aliases.contains(a) => aliases(a).child + }.exists(!_.deterministic)) + } - if (hasNondeterministic) { - p - } else { - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. - val substitutedProjection = projectList1.map(_.transform { - case a: Attribute => aliasMap.getOrElse(a, a) - }).asInstanceOf[Seq[NamedExpression]] - // collapse 2 projects may introduce unnecessary Aliases, trim them here. - val cleanedProjection = substitutedProjection.map(p => - CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] - ) - agg.copy(aggregateExpressions = cleanedProjection) - } + private def buildCleanedProjectList( + upper: Seq[NamedExpression], + lower: Seq[NamedExpression]): Seq[NamedExpression] = { + // Create a map of Aliases to their values from the lower projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliases = collectAliases(lower) + + // Substitute any attributes that are produced by the lower projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + val rewrittenUpper = upper.map(_.transform { + case a: Attribute => aliases.getOrElse(a, a) + }) + // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. + rewrittenUpper.map { p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 833f054659..587437e9aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -29,7 +29,7 @@ class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) :: - Batch("CollapseProject", Once, CollapseProject) :: Nil + Batch("CollapseProject", Once, CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int) @@ -95,4 +95,28 @@ class CollapseProjectSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("collapse project into aggregate") { + val query = testRelation + .groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse common nondeterministic project and aggregate") { + val query = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } } -- cgit v1.2.3