aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorXiao Li <gatorsmile@gmail.com>2017-03-08 09:36:01 -0800
committerXiao Li <gatorsmile@gmail.com>2017-03-08 09:36:01 -0800
commit9a6ac7226fd09d570cae08d0daea82d9bca189a0 (patch)
tree19d31d8de6fd26ad6f3168891d92b693e19fa802 /sql/catalyst
parent5f7d835d380c1a558a4a6d8366140cd96ee202eb (diff)
downloadspark-9a6ac7226fd09d570cae08d0daea82d9bca189a0.tar.gz
spark-9a6ac7226fd09d570cae08d0daea82d9bca189a0.tar.bz2
spark-9a6ac7226fd09d570cae08d0daea82d9bca189a0.zip
[SPARK-19601][SQL] Fix CollapseRepartition rule to preserve shuffle-enabled Repartition
### What changes were proposed in this pull request?

As observed by felixcheung in https://github.com/apache/spark/pull/16739, when users call the shuffle-enabled `repartition` API, they expect the resulting number of partitions to be exactly the number they provided, even if they later call the shuffle-disabled `coalesce` API. Currently, the `CollapseRepartition` rule does not consider whether shuffle is enabled or not. As a result, we get the following unexpected behavior:

```Scala
val df = spark.range(0, 10000, 1, 5)
val df2 = df.repartition(10)
assert(df2.coalesce(13).rdd.getNumPartitions == 5)
assert(df2.coalesce(7).rdd.getNumPartitions == 5)
assert(df2.coalesce(3).rdd.getNumPartitions == 3)
```

This PR fixes the issue by preserving the shuffle-enabled `Repartition`.

### How was this patch tested?

Added test cases to `CollapseRepartitionSuite`.

Author: Xiao Li <gatorsmile@gmail.com>

Closes #16933 from gatorsmile/CollapseRepartition.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala153
4 files changed, 166 insertions, 38 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 0f0d90494f..35ca2a0aa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -370,6 +370,9 @@ package object dsl {
def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)
+ def coalesce(num: Integer): LogicalPlan =
+ Repartition(num, shuffle = false, logicalPlan)
+
def repartition(num: Integer): LogicalPlan =
Repartition(num, shuffle = true, logicalPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d5bbc6e8ac..caafa1c134 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -564,27 +564,23 @@ object CollapseProject extends Rule[LogicalPlan] {
}
/**
- * Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations
- * by keeping only the one.
- * 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]].
- * 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]].
- * 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single
- * [[RepartitionByExpression]] with the expression and last number of partition.
+ * Combines adjacent [[RepartitionOperation]] operators
*/
object CollapseRepartition extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- // Case 1
- case Repartition(numPartitions, shuffle, Repartition(_, _, child)) =>
- Repartition(numPartitions, shuffle, child)
- // Case 2
- case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) =>
- RepartitionByExpression(exprs, child, numPartitions)
- // Case 3
- case Repartition(numPartitions, _, r: RepartitionByExpression) =>
- r.copy(numPartitions = numPartitions)
- // Case 3
- case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
- RepartitionByExpression(exprs, child, numPartitions)
+ // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
+ // 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child
+ // enables the shuffle, returns the child node if the top node's numPartitions is greater
+ // than or equal to the child's numPartitions; otherwise, keeps the plan unchanged.
+ // 2) In the other cases, returns the top node with the child's child
+ case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match {
+ case (false, true) => if (r.numPartitions >= child.numPartitions) child else r
+ case _ => r.copy(child = child.child)
+ }
+ // Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression
+ // we can remove the child.
+ case r @ RepartitionByExpression(_, child: RepartitionOperation, _) =>
+ r.copy(child = child.child)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 70c5ed4b07..31b6ed48a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -843,15 +843,23 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
}
/**
+ * A base interface for [[RepartitionByExpression]] and [[Repartition]]
+ */
+abstract class RepartitionOperation extends UnaryNode {
+ def shuffle: Boolean
+ def numPartitions: Int
+ override def output: Seq[Attribute] = child.output
+}
+
+/**
* Returns a new RDD that has exactly `numPartitions` partitions. Differs from
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
* asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer
* of the output requires some specific ordering or distribution of the data.
*/
case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
- extends UnaryNode {
+ extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
- override def output: Seq[Attribute] = child.output
}
/**
@@ -863,12 +871,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
case class RepartitionByExpression(
partitionExpressions: Seq[Expression],
child: LogicalPlan,
- numPartitions: Int) extends UnaryNode {
+ numPartitions: Int) extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
override def maxRows: Option[Long] = child.maxRows
- override def output: Seq[Attribute] = child.output
+ override def shuffle: Boolean = true
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
index 8952c72fe4..59d2dc46f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
@@ -32,47 +32,168 @@ class CollapseRepartitionSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int)
+
+ test("collapse two adjacent coalesces into one") {
+ // Always respects the top coalesce and removes the useless coalesce below it
+ val query1 = testRelation
+ .coalesce(10)
+ .coalesce(20)
+ val query2 = testRelation
+ .coalesce(30)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.coalesce(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
test("collapse two adjacent repartitions into one") {
- val query = testRelation
+ // Always respects the top repartition and removes the useless repartition below it
+ val query1 = testRelation
+ .repartition(10)
+ .repartition(20)
+ val query2 = testRelation
+ .repartition(30)
+ .repartition(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.repartition(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
+ test("coalesce above repartition") {
+ // Remove useless coalesce above repartition
+ val query1 = testRelation
.repartition(10)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val correctAnswer1 = testRelation.repartition(10).analyze
+
+ comparePlans(optimized1, correctAnswer1)
+
+ // No change in this case
+ val query2 = testRelation
+ .repartition(30)
+ .coalesce(20)
+
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer2 = query2.analyze
+
+ comparePlans(optimized2, correctAnswer2)
+ }
+
+ test("repartition above coalesce") {
+ // Always respects the top repartition and removes the useless coalesce below it
+ val query1 = testRelation
+ .coalesce(10)
+ .repartition(20)
+ val query2 = testRelation
+ .coalesce(30)
.repartition(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.repartition(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
- test("collapse repartition and repartitionBy into one") {
- val query = testRelation
+ test("repartitionBy above repartition") {
+ // Always respects the top repartitionBy and removes the useless repartition below it
+ val query1 = testRelation
.repartition(10)
.distribute('a)(20)
+ val query2 = testRelation
+ .repartition(30)
+ .distribute('a)(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
- test("collapse repartitionBy and repartition into one") {
- val query = testRelation
+ test("repartitionBy above coalesce") {
+ // Always respects the top repartitionBy and removes the useless coalesce below it
+ val query1 = testRelation
+ .coalesce(10)
+ .distribute('a)(20)
+ val query2 = testRelation
+ .coalesce(30)
.distribute('a)(20)
- .repartition(10)
- val optimized = Optimize.execute(query.analyze)
- val correctAnswer = testRelation.distribute('a)(10).analyze
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
+ test("repartition above repartitionBy") {
+ // Always respects the top repartition and removes the useless repartitionBy below it
+ val query1 = testRelation
+ .distribute('a)(10)
+ .repartition(20)
+ val query2 = testRelation
+ .distribute('a)(30)
+ .repartition(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.repartition(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+
+ }
+
+ test("coalesce above repartitionBy") {
+ // Remove useless coalesce above repartitionBy
+ val query1 = testRelation
+ .distribute('a)(10)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val correctAnswer1 = testRelation.distribute('a)(10).analyze
+
+ comparePlans(optimized1, correctAnswer1)
+
+ // No change in this case
+ val query2 = testRelation
+ .distribute('a)(30)
+ .coalesce(20)
+
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer2 = query2.analyze
+
+ comparePlans(optimized2, correctAnswer2)
}
test("collapse two adjacent repartitionBys into one") {
- val query = testRelation
+ // Always respects the top repartitionBy
+ val query1 = testRelation
.distribute('b)(10)
.distribute('a)(20)
+ val query2 = testRelation
+ .distribute('b)(30)
+ .distribute('a)(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
}