aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2016-09-29 14:30:23 -0700
committerHerman van Hovell <hvanhovell@databricks.com>2016-09-29 14:30:23 -0700
commit566d7f28275f90f7b9bed6a75e90989ad0c59931 (patch)
treedd8f981ef27e3f14d2fa4d15c344659b3bd62130
parentfe33121a53384811a8e094ab6c05dc85b7c7ca87 (diff)
downloadspark-566d7f28275f90f7b9bed6a75e90989ad0c59931.tar.gz
spark-566d7f28275f90f7b9bed6a75e90989ad0c59931.tar.bz2
spark-566d7f28275f90f7b9bed6a75e90989ad0c59931.zip
[SPARK-17653][SQL] Remove unnecessary distincts in multiple unions
## What changes were proposed in this pull request? Currently, for `Union [Distinct]`, a `Distinct` operator must be placed on top of the `Union`. When there are adjacent `Union [Distinct]` operators, multiple `Distinct` operators appear in the query plan. E.g., for a query like: select 1 a union select 2 b union select 3 c Before this patch, its physical plan looks like: *HashAggregate(keys=[a#13], functions=[]) +- Exchange hashpartitioning(a#13, 200) +- *HashAggregate(keys=[a#13], functions=[]) +- Union :- *HashAggregate(keys=[a#13], functions=[]) : +- Exchange hashpartitioning(a#13, 200) : +- *HashAggregate(keys=[a#13], functions=[]) : +- Union : :- *Project [1 AS a#13] : : +- Scan OneRowRelation[] : +- *Project [2 AS b#14] : +- Scan OneRowRelation[] +- *Project [3 AS c#15] +- Scan OneRowRelation[] Only the topmost `Distinct` is actually necessary. After this patch, the physical plan looks like: *HashAggregate(keys=[a#221], functions=[], output=[a#221]) +- Exchange hashpartitioning(a#221, 5) +- *HashAggregate(keys=[a#221], functions=[], output=[a#221]) +- Union :- *Project [1 AS a#221] : +- Scan OneRowRelation[] :- *Project [2 AS b#222] : +- Scan OneRowRelation[] +- *Project [3 AS c#223] +- Scan OneRowRelation[] ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #15238 from viirya/remove-extra-distinct-union.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala24
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala27
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala68
3 files changed, 89 insertions, 30 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4952ba3b2b..9df8ce1fa3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import scala.collection.immutable.HashSet
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.api.java.function.FilterFunction
@@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
-import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
+import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -579,8 +580,25 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
object CombineUnions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Unions(children) => Union(children)
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+ case u: Union => flattenUnion(u, false)
+ case Distinct(u: Union) => Distinct(flattenUnion(u, true))
+ }
+
+ private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = {
+ val stack = mutable.Stack[LogicalPlan](union)
+ val flattened = mutable.ArrayBuffer.empty[LogicalPlan]
+ while (stack.nonEmpty) {
+ stack.pop() match {
+ case Distinct(Union(children)) if flattenDistinct =>
+ stack.pushAll(children.reverse)
+ case Union(children) =>
+ stack.pushAll(children.reverse)
+ case child =>
+ flattened += child
+ }
+ }
+ Union(flattened)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 41cabb8cb3..bdae56881b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -188,33 +188,6 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
}
}
-
-/**
- * A pattern that collects all adjacent unions and returns their children as a Seq.
- */
-object Unions {
- def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
- case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
- case _ => None
- }
-
- // Doing a depth-first tree traversal to combine all the union children.
- @tailrec
- private def collectUnionChildren(
- plans: mutable.Stack[LogicalPlan],
- children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
- if (plans.isEmpty) children
- else {
- plans.pop match {
- case Union(grandchildren) =>
- grandchildren.reverseMap(plans.push(_))
- collectUnionChildren(plans, children)
- case other => collectUnionChildren(plans, children :+ other)
- }
- }
- }
-}
-
/**
* An extractor used when planning the physical execution of an aggregation. Compared with a logical
* aggregation, the following transformations are performed:
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 7227706ab2..21b7f49e14 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -76,4 +77,71 @@ class SetOperationSuite extends PlanTest {
testRelation3.select('g) :: Nil).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}
+
+ test("Remove unnecessary distincts in multiple unions") {
+ val query1 = OneRowRelation
+ .select(Literal(1).as('a))
+ val query2 = OneRowRelation
+ .select(Literal(2).as('b))
+ val query3 = OneRowRelation
+ .select(Literal(3).as('c))
+
+ // D - U - D - U - query1
+ // | |
+ // query3 query2
+ val unionQuery1 = Distinct(Union(Distinct(Union(query1, query2)), query3)).analyze
+ val optimized1 = Optimize.execute(unionQuery1)
+ val distinctUnionCorrectAnswer1 =
+ Distinct(Union(query1 :: query2 :: query3 :: Nil)).analyze
+ comparePlans(distinctUnionCorrectAnswer1, optimized1)
+
+ // query1
+ // |
+ // D - U - U - query2
+ // |
+ // D - U - query2
+ // |
+ // query3
+ val unionQuery2 = Distinct(Union(Union(query1, query2),
+ Distinct(Union(query2, query3)))).analyze
+ val optimized2 = Optimize.execute(unionQuery2)
+ val distinctUnionCorrectAnswer2 =
+ Distinct(Union(query1 :: query2 :: query2 :: query3 :: Nil)).analyze
+ comparePlans(distinctUnionCorrectAnswer2, optimized2)
+ }
+
+ test("Keep necessary distincts in multiple unions") {
+ val query1 = OneRowRelation
+ .select(Literal(1).as('a))
+ val query2 = OneRowRelation
+ .select(Literal(2).as('b))
+ val query3 = OneRowRelation
+ .select(Literal(3).as('c))
+ val query4 = OneRowRelation
+ .select(Literal(4).as('d))
+
+ // U - D - U - query1
+ // | |
+ // query3 query2
+ val unionQuery1 = Union(Distinct(Union(query1, query2)), query3).analyze
+ val optimized1 = Optimize.execute(unionQuery1)
+ val distinctUnionCorrectAnswer1 =
+ Union(Distinct(Union(query1 :: query2 :: Nil)) :: query3 :: Nil).analyze
+ comparePlans(distinctUnionCorrectAnswer1, optimized1)
+
+ // query1
+ // |
+ // U - D - U - query2
+ // |
+ // D - U - query3
+ // |
+ // query4
+ val unionQuery2 =
+ Union(Distinct(Union(query1, query2)), Distinct(Union(query3, query4))).analyze
+ val optimized2 = Optimize.execute(unionQuery2)
+ val distinctUnionCorrectAnswer2 =
+ Union(Distinct(Union(query1 :: query2 :: Nil)),
+ Distinct(Union(query3 :: query4 :: Nil))).analyze
+ comparePlans(distinctUnionCorrectAnswer2, optimized2)
+ }
}