diff options
-rw-r--r-- | core/src/main/scala/spark/Partitioner.scala | 23
-rw-r--r-- | core/src/test/scala/spark/ShuffleSuite.scala | 2
2 files changed, 19 insertions, 6 deletions
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index 03966f1c96..eec0e8dd79 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -10,12 +10,21 @@ abstract class Partitioner extends Serializable { } object Partitioner { + + private val useDefaultParallelism = System.getProperty("spark.default.parallelism") != null + /** - * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of - * the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner. + * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. + * + * If any of the RDDs already has a partitioner, choose that one. * - * The number of partitions will be the same as the number of partitions in the largest upstream - * RDD, as this should be least likely to cause out-of-memory errors. + * Otherwise, we use a default HashPartitioner. For the number of partitions, if + * spark.default.parallelism is set, then we'll use the value from SparkContext + * defaultParallelism, otherwise we'll use the max number of upstream partitions. + * + * Unless spark.default.parallelism is set, the number of partitions will be the + * same as the number of partitions in the largest upstream RDD, as this should + * be least likely to cause out-of-memory errors. + * * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. 
*/ @@ -24,7 +33,11 @@ object Partitioner { for (r <- bySize if r.partitioner != None) { return r.partitioner.get } - return new HashPartitioner(bySize.head.partitions.size) + if (useDefaultParallelism) { + return new HashPartitioner(rdd.context.defaultParallelism) + } else { + return new HashPartitioner(bySize.head.partitions.size) + } } } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 2099999ed7..8411291b2c 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -235,7 +235,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(rdd.values.collect().toList === List("a", "b")) } - test("default partitioner uses split size") { + test("default partitioner uses partition size") { sc = new SparkContext("local", "test") // specify 2000 partitions val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)