diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2016-09-29 14:30:23 -0700 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-09-29 14:30:23 -0700 |
commit | 566d7f28275f90f7b9bed6a75e90989ad0c59931 (patch) | |
tree | dd8f981ef27e3f14d2fa4d15c344659b3bd62130 /sql/catalyst | |
parent | fe33121a53384811a8e094ab6c05dc85b7c7ca87 (diff) | |
download | spark-566d7f28275f90f7b9bed6a75e90989ad0c59931.tar.gz spark-566d7f28275f90f7b9bed6a75e90989ad0c59931.tar.bz2 spark-566d7f28275f90f7b9bed6a75e90989ad0c59931.zip |
[SPARK-17653][SQL] Remove unnecessary distincts in multiple unions
## What changes were proposed in this pull request?
Currently for `Union [Distinct]`, a `Distinct` operator is necessary to be on the top of `Union`. Once there are adjacent `Union [Distinct]`, there will be multiple `Distinct` in the query plan.
E.g.,
For a query like: select 1 a union select 2 b union select 3 c
Before this patch, its physical plan looks like:
*HashAggregate(keys=[a#13], functions=[])
+- Exchange hashpartitioning(a#13, 200)
+- *HashAggregate(keys=[a#13], functions=[])
+- Union
:- *HashAggregate(keys=[a#13], functions=[])
: +- Exchange hashpartitioning(a#13, 200)
: +- *HashAggregate(keys=[a#13], functions=[])
: +- Union
: :- *Project [1 AS a#13]
: : +- Scan OneRowRelation[]
: +- *Project [2 AS b#14]
: +- Scan OneRowRelation[]
+- *Project [3 AS c#15]
+- Scan OneRowRelation[]
Only the top distinct should be necessary.
After this patch, the physical plan looks like:
*HashAggregate(keys=[a#221], functions=[], output=[a#221])
+- Exchange hashpartitioning(a#221, 5)
+- *HashAggregate(keys=[a#221], functions=[], output=[a#221])
+- Union
:- *Project [1 AS a#221]
: +- Scan OneRowRelation[]
:- *Project [2 AS b#222]
: +- Scan OneRowRelation[]
+- *Project [3 AS c#223]
+- Scan OneRowRelation[]
## How was this patch tested?
Jenkins tests.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #15238 from viirya/remove-extra-distinct-union.
Diffstat (limited to 'sql/catalyst')
3 files changed, 89 insertions, 30 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4952ba3b2b..9df8ce1fa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.java.function.FilterFunction @@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -579,8 +580,25 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe * Combines all adjacent [[Union]] operators into a single [[Union]]. */ object CombineUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Unions(children) => Union(children) + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case u: Union => flattenUnion(u, false) + case Distinct(u: Union) => Distinct(flattenUnion(u, true)) + } + + private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = { + val stack = mutable.Stack[LogicalPlan](union) + val flattened = mutable.ArrayBuffer.empty[LogicalPlan] + while (stack.nonEmpty) { + stack.pop() match { + case Distinct(Union(children)) if flattenDistinct => + stack.pushAll(children.reverse) + case Union(children) => + stack.pushAll(children.reverse) + case child => + flattened += child + } + } + Union(flattened) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 41cabb8cb3..bdae56881b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -188,33 +188,6 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { } } - -/** - * A pattern that collects all adjacent unions and returns their children as a Seq. - */ -object Unions { - def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan])) - case _ => None - } - - // Doing a depth-first tree traversal to combine all the union children. - @tailrec - private def collectUnionChildren( - plans: mutable.Stack[LogicalPlan], - children: Seq[LogicalPlan]): Seq[LogicalPlan] = { - if (plans.isEmpty) children - else { - plans.pop match { - case Union(grandchildren) => - grandchildren.reverseMap(plans.push(_)) - collectUnionChildren(plans, children) - case other => collectUnionChildren(plans, children :+ other) - } - } - } -} - /** * An extractor used when planning the physical execution of an aggregation. Compared with a logical * aggregation, the following transformations are performed: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 7227706ab2..21b7f49e14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -76,4 +77,71 @@ class SetOperationSuite extends PlanTest { testRelation3.select('g) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } + + test("Remove unnecessary distincts in multiple unions") { + val query1 = OneRowRelation + .select(Literal(1).as('a)) + val query2 = OneRowRelation + .select(Literal(2).as('b)) + val query3 = OneRowRelation + .select(Literal(3).as('c)) + + // D - U - D - U - query1 + // | | + // query3 query2 + val unionQuery1 = Distinct(Union(Distinct(Union(query1, query2)), query3)).analyze + val optimized1 = Optimize.execute(unionQuery1) + val distinctUnionCorrectAnswer1 = + Distinct(Union(query1 :: query2 :: query3 :: Nil)).analyze + comparePlans(distinctUnionCorrectAnswer1, optimized1) + + // query1 + // | + // D - U - U - query2 + // | + // D - U - query2 + // | + // query3 + val unionQuery2 = Distinct(Union(Union(query1, query2), + Distinct(Union(query2, query3)))).analyze + val optimized2 = Optimize.execute(unionQuery2) + val distinctUnionCorrectAnswer2 = + Distinct(Union(query1 :: query2 :: query2 :: query3 :: Nil)).analyze + comparePlans(distinctUnionCorrectAnswer2, optimized2) + } + + test("Keep necessary distincts in multiple unions") { + val query1 = OneRowRelation + .select(Literal(1).as('a)) + val query2 = OneRowRelation + .select(Literal(2).as('b)) + val query3 = OneRowRelation + .select(Literal(3).as('c)) + val query4 = OneRowRelation + .select(Literal(4).as('d)) + + // U - D - U - query1 + // | | + // query3 query2 + val unionQuery1 = Union(Distinct(Union(query1, query2)), query3).analyze + val optimized1 = Optimize.execute(unionQuery1) + val distinctUnionCorrectAnswer1 = + Union(Distinct(Union(query1 :: query2 :: Nil)) :: query3 :: Nil).analyze + comparePlans(distinctUnionCorrectAnswer1, optimized1) + + // query1 + // | + // U - D - U - query2 + // | + // D - U - query3 + // | + // query4 + val unionQuery2 = + Union(Distinct(Union(query1, query2)), Distinct(Union(query3, query4))).analyze + val optimized2 = Optimize.execute(unionQuery2) + val distinctUnionCorrectAnswer2 = + Union(Distinct(Union(query1 :: query2 :: Nil)), + Distinct(Union(query3 :: query4 :: Nil))).analyze + comparePlans(distinctUnionCorrectAnswer2, optimized2) + } } |