aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-03-24 00:51:31 +0800
committerCheng Lian <lian@databricks.com>2016-03-24 00:51:31 +0800
commit6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e (patch)
tree8d0f2ebe3aea31f2a36c9e9a85e2ac15ab526841
parentcde086cb2a9a85406fc18d8e63e46425f614c15f (diff)
downloadspark-6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e.tar.gz
spark-6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e.tar.bz2
spark-6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e.zip
[SPARK-13549][SQL] Refactor the Optimizer Rule CollapseProject
#### What changes were proposed in this pull request? The PR https://github.com/apache/spark/pull/10541 changed the rule `CollapseProject` to allow collapsing a `Project` into an `Aggregate`. It left a to-do item to remove the resulting duplicate code. This PR completes that to-do item and adds a test case covering the change. #### How was this patch tested? Added a new test case. liancheng Could you check if the code refactoring is fine? Thanks! Author: gatorsmile <gatorsmile@gmail.com> Closes #11427 from gatorsmile/collapseProjectRefactor.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala101
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala26
2 files changed, 70 insertions, 57 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0840d46e4e..4cfdcf95cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -417,68 +417,57 @@ object ColumnPruning extends Rule[LogicalPlan] {
object CollapseProject extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case p @ Project(projectList1, Project(projectList2, child)) =>
- // Create a map of Aliases to their values from the child projection.
- // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
- val aliasMap = AttributeMap(projectList2.collect {
- case a: Alias => (a.toAttribute, a)
- })
-
- // We only collapse these two Projects if their overlapped expressions are all
- // deterministic.
- val hasNondeterministic = projectList1.exists(_.collect {
- case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
- }.exists(!_.deterministic))
-
- if (hasNondeterministic) {
+ case p1 @ Project(_, p2: Project) =>
+ if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+ p1
+ } else {
+ p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
+ }
+ case p @ Project(_, agg: Aggregate) =>
+ if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
p
} else {
- // Substitute any attributes that are produced by the child projection, so that we safely
- // eliminate it.
- // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
- // TODO: Fix TransformBase to avoid the cast below.
- val substitutedProjection = projectList1.map(_.transform {
- case a: Attribute => aliasMap.getOrElse(a, a)
- }).asInstanceOf[Seq[NamedExpression]]
- // collapse 2 projects may introduce unnecessary Aliases, trim them here.
- val cleanedProjection = substitutedProjection.map(p =>
- CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
- )
- Project(cleanedProjection, child)
+ agg.copy(aggregateExpressions = buildCleanedProjectList(
+ p.projectList, agg.aggregateExpressions))
}
+ }
- // TODO Eliminate duplicate code
- // This clause is identical to the one above except that the inner operator is an `Aggregate`
- // rather than a `Project`.
- case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
- // Create a map of Aliases to their values from the child projection.
- // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
- val aliasMap = AttributeMap(projectList2.collect {
- case a: Alias => (a.toAttribute, a)
- })
+ private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
+ AttributeMap(projectList.collect {
+ case a: Alias => a.toAttribute -> a
+ })
+ }
- // We only collapse these two Projects if their overlapped expressions are all
- // deterministic.
- val hasNondeterministic = projectList1.exists(_.collect {
- case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
- }.exists(!_.deterministic))
+ private def haveCommonNonDeterministicOutput(
+ upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
+ // Create a map of Aliases to their values from the lower projection.
+ // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+ val aliases = collectAliases(lower)
+
+ // Collapse upper and lower Projects if and only if their overlapped expressions are all
+ // deterministic.
+ upper.exists(_.collect {
+ case a: Attribute if aliases.contains(a) => aliases(a).child
+ }.exists(!_.deterministic))
+ }
- if (hasNondeterministic) {
- p
- } else {
- // Substitute any attributes that are produced by the child projection, so that we safely
- // eliminate it.
- // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
- // TODO: Fix TransformBase to avoid the cast below.
- val substitutedProjection = projectList1.map(_.transform {
- case a: Attribute => aliasMap.getOrElse(a, a)
- }).asInstanceOf[Seq[NamedExpression]]
- // collapse 2 projects may introduce unnecessary Aliases, trim them here.
- val cleanedProjection = substitutedProjection.map(p =>
- CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
- )
- agg.copy(aggregateExpressions = cleanedProjection)
- }
+ private def buildCleanedProjectList(
+ upper: Seq[NamedExpression],
+ lower: Seq[NamedExpression]): Seq[NamedExpression] = {
+ // Create a map of Aliases to their values from the lower projection.
+ // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+ val aliases = collectAliases(lower)
+
+ // Substitute any attributes that are produced by the lower projection, so that we safely
+ // eliminate it.
+ // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
+ val rewrittenUpper = upper.map(_.transform {
+ case a: Attribute => aliases.getOrElse(a, a)
+ })
+ // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
+ rewrittenUpper.map { p =>
+ CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 833f054659..587437e9aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -29,7 +29,7 @@ class CollapseProjectSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) ::
- Batch("CollapseProject", Once, CollapseProject) :: Nil
+ Batch("CollapseProject", Once, CollapseProject) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int)
@@ -95,4 +95,28 @@ class CollapseProjectSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("collapse project into aggregate") {
+ val query = testRelation
+ .groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b)
+ .select('a_plus_1, ('b + 1).as('b_plus_1))
+
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer = testRelation
+ .groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("do not collapse common nondeterministic project and aggregate") {
+ val query = testRelation
+ .groupBy('a)('a, Rand(10).as('rand))
+ .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = query.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}