aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala153
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala9
7 files changed, 178 insertions, 49 deletions
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 620b633637..9735fe3201 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2592,8 +2592,8 @@ test_that("coalesce, repartition, numPartitions", {
df2 <- repartition(df1, 10)
expect_equal(getNumPartitions(df2), 10)
- expect_equal(getNumPartitions(coalesce(df2, 13)), 5)
- expect_equal(getNumPartitions(coalesce(df2, 7)), 5)
+ expect_equal(getNumPartitions(coalesce(df2, 13)), 10)
+ expect_equal(getNumPartitions(coalesce(df2, 7)), 7)
expect_equal(getNumPartitions(coalesce(df2, 3)), 3)
})
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 0f0d90494f..35ca2a0aa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -370,6 +370,9 @@ package object dsl {
def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)
+ def coalesce(num: Integer): LogicalPlan =
+ Repartition(num, shuffle = false, logicalPlan)
+
def repartition(num: Integer): LogicalPlan =
Repartition(num, shuffle = true, logicalPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d5bbc6e8ac..caafa1c134 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -564,27 +564,23 @@ object CollapseProject extends Rule[LogicalPlan] {
}
/**
- * Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations
- * by keeping only the one.
- * 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]].
- * 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]].
- * 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single
- * [[RepartitionByExpression]] with the expression and last number of partition.
+ * Combines adjacent [[RepartitionOperation]] operators
*/
object CollapseRepartition extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- // Case 1
- case Repartition(numPartitions, shuffle, Repartition(_, _, child)) =>
- Repartition(numPartitions, shuffle, child)
- // Case 2
- case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) =>
- RepartitionByExpression(exprs, child, numPartitions)
- // Case 3
- case Repartition(numPartitions, _, r: RepartitionByExpression) =>
- r.copy(numPartitions = numPartitions)
- // Case 3
- case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
- RepartitionByExpression(exprs, child, numPartitions)
+ // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
+ // 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child
+ // enables the shuffle. Returns the child node if the last numPartitions is bigger;
+ // otherwise, keep unchanged.
+ // 2) In the other cases, returns the top node with the child's child
+ case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match {
+ case (false, true) => if (r.numPartitions >= child.numPartitions) child else r
+ case _ => r.copy(child = child.child)
+ }
+ // Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression
+ // we can remove the child.
+ case r @ RepartitionByExpression(_, child: RepartitionOperation, _) =>
+ r.copy(child = child.child)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 70c5ed4b07..31b6ed48a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -843,15 +843,23 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
}
/**
+ * A base interface for [[RepartitionByExpression]] and [[Repartition]]
+ */
+abstract class RepartitionOperation extends UnaryNode {
+ def shuffle: Boolean
+ def numPartitions: Int
+ override def output: Seq[Attribute] = child.output
+}
+
+/**
* Returns a new RDD that has exactly `numPartitions` partitions. Differs from
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
* asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer
* of the output requires some specific ordering or distribution of the data.
*/
case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
- extends UnaryNode {
+ extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
- override def output: Seq[Attribute] = child.output
}
/**
@@ -863,12 +871,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
case class RepartitionByExpression(
partitionExpressions: Seq[Expression],
child: LogicalPlan,
- numPartitions: Int) extends UnaryNode {
+ numPartitions: Int) extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
override def maxRows: Option[Long] = child.maxRows
- override def output: Seq[Attribute] = child.output
+ override def shuffle: Boolean = true
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
index 8952c72fe4..59d2dc46f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
@@ -32,47 +32,168 @@ class CollapseRepartitionSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int)
+
+ test("collapse two adjacent coalesces into one") {
+ // Always respects the top coalesce and removes the useless coalesce below it
+ val query1 = testRelation
+ .coalesce(10)
+ .coalesce(20)
+ val query2 = testRelation
+ .coalesce(30)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.coalesce(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
test("collapse two adjacent repartitions into one") {
- val query = testRelation
+ // Always respects the top repartition and removes useless repartition below repartition
+ val query1 = testRelation
+ .repartition(10)
+ .repartition(20)
+ val query2 = testRelation
+ .repartition(30)
+ .repartition(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.repartition(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
+ test("coalesce above repartition") {
+ // Remove useless coalesce above repartition
+ val query1 = testRelation
.repartition(10)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val correctAnswer1 = testRelation.repartition(10).analyze
+
+ comparePlans(optimized1, correctAnswer1)
+
+ // No change in this case
+ val query2 = testRelation
+ .repartition(30)
+ .coalesce(20)
+
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer2 = query2.analyze
+
+ comparePlans(optimized2, correctAnswer2)
+ }
+
+ test("repartition above coalesce") {
+ // Always respects the top repartition and removes useless coalesce below repartition
+ val query1 = testRelation
+ .coalesce(10)
+ .repartition(20)
+ val query2 = testRelation
+ .coalesce(30)
.repartition(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.repartition(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
- test("collapse repartition and repartitionBy into one") {
- val query = testRelation
+ test("repartitionBy above repartition") {
+ // Always respects the top repartitionBy and removes useless repartition
+ val query1 = testRelation
.repartition(10)
.distribute('a)(20)
+ val query2 = testRelation
+ .repartition(30)
+ .distribute('a)(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
- test("collapse repartitionBy and repartition into one") {
- val query = testRelation
+ test("repartitionBy above coalesce") {
+ // Always respects the top repartitionBy and removes useless coalesce below repartitionBy
+ val query1 = testRelation
+ .coalesce(10)
+ .distribute('a)(20)
+ val query2 = testRelation
+ .coalesce(30)
.distribute('a)(20)
- .repartition(10)
- val optimized = Optimize.execute(query.analyze)
- val correctAnswer = testRelation.distribute('a)(10).analyze
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
+ test("repartition above repartitionBy") {
+ // Always respects the top repartition and removes useless distribute below repartition
+ val query1 = testRelation
+ .distribute('a)(10)
+ .repartition(20)
+ val query2 = testRelation
+ .distribute('a)(30)
+ .repartition(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.repartition(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+
+ }
+
+ test("coalesce above repartitionBy") {
+ // Remove useless coalesce above repartitionBy
+ val query1 = testRelation
+ .distribute('a)(10)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val correctAnswer1 = testRelation.distribute('a)(10).analyze
+
+ comparePlans(optimized1, correctAnswer1)
+
+ // No change in this case
+ val query2 = testRelation
+ .distribute('a)(30)
+ .coalesce(20)
+
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer2 = query2.analyze
+
+ comparePlans(optimized2, correctAnswer2)
}
test("collapse two adjacent repartitionBys into one") {
- val query = testRelation
+ // Always respects the top repartitionBy
+ val query1 = testRelation
.distribute('b)(10)
.distribute('a)(20)
+ val query2 = testRelation
+ .distribute('b)(30)
+ .distribute('a)(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f00311fc32..16edb35b1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2441,11 +2441,11 @@ class Dataset[T] private[sql](
}
/**
- * Returns a new Dataset that has exactly `numPartitions` partitions.
- * Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g.
- * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
- * the 100 new partitions will claim 10 of the current partitions. If a larger number of
- * partitions is requested, it will stay at the current number of partitions.
+ * Returns a new Dataset that has exactly `numPartitions` partitions, when fewer partitions
+ * are requested. If a larger number of partitions is requested, it will stay at the current
+ * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in
+ * a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not
+ * be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.
*
* However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
* this may result in your computation taking place on fewer nodes than
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 0bfc92fdb6..02ccebd22b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -242,11 +242,12 @@ class PlannerSuite extends SharedSQLContext {
val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
- assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1)
+ assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2)
doubleRepartitioned.queryExecution.optimizedPlan match {
- case r: Repartition =>
- assert(r.numPartitions === 5)
- assert(r.shuffle === false)
+ case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) =>
+ assert(numPartitions === 5)
+ assert(shuffle === false)
+ assert(shuffleChild === true)
}
}