aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala8
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala8
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala19
3 files changed, 29 insertions, 6 deletions
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index cc3cca2571..18b4a1eca4 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -439,12 +439,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of
* the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner.
+ *
+ * The number of partitions will be the same as the number of partitions in the largest upstream
+ * RDD, as this should be least likely to cause out-of-memory errors.
*/
def defaultPartitioner(rdds: RDD[_]*): Partitioner = {
- for (r <- rdds if r.partitioner != None) {
+ val bySize = rdds.sortBy(_.splits.size).reverse
+ for (r <- bySize if r.partitioner != None) {
return r.partitioner.get
}
- return new HashPartitioner(self.context.defaultParallelism)
+ return new HashPartitioner(bySize.head.splits.size)
}
/**
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index af1107cd19..60db759c25 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -84,10 +84,10 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner)
assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner)
- assert(grouped2.join(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.cogroup(grouped4).partitioner === grouped2.partitioner)
+ assert(grouped2.join(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner)
assert(grouped2.join(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..ab7060a1ac 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -211,6 +211,25 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(rdd.keys.collect().toList === List(1, 2))
assert(rdd.values.collect().toList === List("a", "b"))
}
+
+ test("default partition size uses split size") {
+ sc = new SparkContext("local", "test")
+ // specify 2000 splits
+ val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+ // do a map, which loses the partitioner
+ val b = a.map(a => (a, (a * 2).toString))
+ // then a group by, and see we didn't revert to 2 splits
+ val c = b.groupByKey()
+ assert(c.splits.size === 2000)
+ }
+
+ test("default partition uses largest partitioner") {
+ sc = new SparkContext("local", "test")
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.splits.size === 2000)
+ }
}
object ShuffleSuite {