-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala       15
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala  33
2 files changed, 46 insertions, 2 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a1ca612cc9..aa03e9276f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -328,11 +328,22 @@ abstract class RDD[T: ClassTag](
   def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null)
       : RDD[T] = {
     if (shuffle) {
+      /** Distributes elements evenly across output partitions, starting from a random partition. */
+      def distributePartition(index: Int, items: Iterator[T]): Iterator[(Int, T)] = {
+        var position = (new Random(index)).nextInt(numPartitions)
+        items.map { t =>
+          // Note that the hash code of the key will just be the key itself. The HashPartitioner
+          // will mod it with the number of total partitions.
+          position = position + 1
+          (position, t)
+        }
+      }
+
       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)),
+        new ShuffledRDD[Int, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
         new HashPartitioner(numPartitions)),
-        numPartitions).keys
+        numPartitions).values
     } else {
       new CoalescedRDD(this, numPartitions)
     }
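
Why this balances the output: an Int key hashes to itself, so HashPartitioner reduces to key modulo numPartitions, and consecutive keys walk the output partitions round-robin from a per-input-partition random start. The sketch below is hypothetical and stands outside Spark: `DistributeSketch` and its parameters are illustrative names, and plain Scala collections stand in for RDD partitions. It reproduces the keying scheme and counts elements per output bucket.

```scala
import scala.util.Random

// Hypothetical standalone sketch of the keying scheme in the patch above;
// plain Scala collections stand in for RDD partitions.
object DistributeSketch {
  // Mirrors distributePartition: start at a random output slot, then round-robin.
  def distributePartition[T](index: Int, items: Iterator[T],
      numPartitions: Int): Iterator[(Int, T)] = {
    var position = new Random(index).nextInt(numPartitions)
    items.map { t =>
      position += 1
      (position, t) // Int keys hash to themselves; the partitioner mods them
    }
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    // Simulate 10 input partitions holding 100 identical elements each.
    val keyed = (0 until 10).flatMap { i =>
      distributePartition(i, Iterator.fill(100)(1), numPartitions).toSeq
    }
    // Emulate HashPartitioner.getPartition: non-negative key modulo numPartitions.
    val counts = keyed
      .groupBy { case (k, _) => ((k % numPartitions) + numPartitions) % numPartitions }
      .map { case (bucket, vs) => bucket -> vs.size }
    println(counts.toSeq.sortBy(_._1)) // each of the 4 buckets receives exactly 250
  }
}
```

Contrast this with the replaced keying, `map(x => (x, null))`: identical input values all hash to the same output partition, which is exactly the skew the new test below provokes with `Array.fill(1000)(1)`.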
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 8da9a0da70..e686068f7a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -202,6 +202,39 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
   }
 
+  test("repartitioned RDDs perform load balancing") {
+    // Coalesce partitions
+    val input = Array.fill(1000)(1)
+    val initialPartitions = 10
+    val data = sc.parallelize(input, initialPartitions)
+
+    val repartitioned1 = data.repartition(2)
+    assert(repartitioned1.partitions.size == 2)
+    val partitions1 = repartitioned1.glom().collect()
+    // some noise in balancing is allowed due to randomization
+    assert(math.abs(partitions1(0).length - 500) < initialPartitions)
+    assert(math.abs(partitions1(1).length - 500) < initialPartitions)
+    assert(repartitioned1.collect() === input)
+
+    def testSplitPartitions(input: Seq[Int], initialPartitions: Int, finalPartitions: Int) {
+      val data = sc.parallelize(input, initialPartitions)
+      val repartitioned = data.repartition(finalPartitions)
+      assert(repartitioned.partitions.size === finalPartitions)
+      val partitions = repartitioned.glom().collect()
+      // assert all elements are present
+      assert(repartitioned.collect().sortWith(_ > _).toSeq === input.toSeq.sortWith(_ > _).toSeq)
+      // assert no bucket is overloaded
+      for (partition <- partitions) {
+        val avg = input.size / finalPartitions
+        val maxPossible = avg + initialPartitions
+        assert(partition.length <= maxPossible)
+      }
+    }
+
+    testSplitPartitions(Array.fill(100)(1), 10, 20)
+    testSplitPartitions(Array.fill(10000)(1) ++ Array.fill(10000)(2), 20, 100)
+  }
+
   test("coalesced RDDs") {
     val data = sc.parallelize(1 to 10, 10)
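
As a usage note, the balanced behavior the test asserts is visible from any driver program. A minimal sketch, assuming a local SparkContext named `sc` as in the suite (the expected counts are reasoned from the patch, not captured from a run):

```scala
// Minimal sketch, assuming a SparkContext `sc` as in RDDSuite.
// Before this patch every element was keyed as (x, null), so 1000 identical
// values all hashed to one output partition; with round-robin keys the two
// partitions come out near 500 each.
val data = sc.parallelize(Array.fill(1000)(1), 10)
val balanced = data.coalesce(2, shuffle = true) // the path repartition(2) exercises
balanced.glom().collect().foreach(p => println(p.length))
```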