aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/Partitioner.scala16
1 files changed, 9 insertions, 7 deletions
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 98c3abe93b..93dfbc0e6e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -55,14 +55,16 @@ object Partitioner {
* We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
*/
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
- val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.length).reverse
- for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) {
- return r.partitioner.get
- }
- if (rdd.context.conf.contains("spark.default.parallelism")) {
- new HashPartitioner(rdd.context.defaultParallelism)
+ val rdds = (Seq(rdd) ++ others)
+ val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
+ if (hasPartitioner.nonEmpty) {
+ hasPartitioner.maxBy(_.partitions.length).partitioner.get
} else {
- new HashPartitioner(bySize.head.partitions.length)
+ if (rdd.context.conf.contains("spark.default.parallelism")) {
+ new HashPartitioner(rdd.context.defaultParallelism)
+ } else {
+ new HashPartitioner(rdds.map(_.partitions.length).max)
+ }
}
}
}