diff options
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/Partitioner.scala | 61 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/PartitioningSuite.scala | 34 |
2 files changed, 95 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 9155159cf6..6274796061 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -156,3 +156,64 @@ class RangePartitioner[K : Ordering : ClassTag, V]( false } } + +/** + * A [[org.apache.spark.Partitioner]] that partitions records into specified bounds + * Default value is 1000. Once all partitions have bounds elements, the partitioner + * allocates 1 element per partition so eventually the smaller partitions are at most + * off by 1 key compared to the larger partitions. + */ +class BoundaryPartitioner[K : Ordering : ClassTag, V]( + partitions: Int, + @transient rdd: RDD[_ <: Product2[K,V]], + private val boundary: Int = 1000) + extends Partitioner { + + // this array keeps track of keys assigned to a partition + // counts[0] refers to # of keys in partition 0 and so on + private val counts: Array[Int] = { + new Array[Int](numPartitions) + } + + def numPartitions = math.abs(partitions) + + /* + * Ideally, this should've been calculated based on # partitions and total keys + * But we are not calling count on RDD here to avoid calling an action. + * User has the flexibility of calling count and passing in any appropriate boundary + */ + def keysPerPartition = boundary + + var currPartition = 0 + + /* + * Pick current partition for the key until we hit the bound for keys / partition, + * start allocating to next partition at that time. + * + * NOTE: In case where we have lets say 2000 keys and user says 3 partitions with 500 + * passed in as boundary, the first 500 will goto P1, 501-1000 go to P2, 1001-1500 go to P3, + * after that, next keys go to one partition at a time. So 1501 goes to P1, 1502 goes to P2, + * 1503 goes to P3 and so on. + */ + def getPartition(key: Any): Int = { + val partition = currPartition + counts(partition) = counts(partition) + 1 + /* + * Since we are filling up a partition before moving to next one (this helps in maintaining + * order of keys, in certain cases, it is possible to end up with empty partitions, like + * 3 partitions, 500 keys / partition and if rdd has 700 keys, 1 partition will be entirely + * empty. + */ + if(counts(currPartition) >= keysPerPartition) + currPartition = (currPartition + 1) % numPartitions + partition + } + + override def equals(other: Any): Boolean = other match { + case r: BoundaryPartitioner[_,_] => + (r.counts.sameElements(counts) && r.boundary == boundary + && r.currPartition == currPartition) + case _ => + false + } +} diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 7c30626a0c..7d40395803 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -66,6 +66,40 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(descendingP4 != p4) } + test("BoundaryPartitioner equality") { + // Make an RDD where all the elements are the same so that the partition range bounds + // are deterministically all the same. + val rdd = sc.parallelize(1.to(4000)).map(x => (x, x)) + + val p2 = new BoundaryPartitioner(2, rdd, 1000) + val p4 = new BoundaryPartitioner(4, rdd, 1000) + val anotherP4 = new BoundaryPartitioner(4, rdd) + + assert(p2 === p2) + assert(p4 === p4) + assert(p2 != p4) + assert(p4 != p2) + assert(p4 === anotherP4) + assert(anotherP4 === p4) + } + + test("BoundaryPartitioner getPartition") { + val rdd = sc.parallelize(1.to(2000)).map(x => (x, x)) + val partitioner = new BoundaryPartitioner(4, rdd, 500) + 1.to(2000).map { element => { + val partition = partitioner.getPartition(element) + if (element <= 500) { + assert(partition === 0) + } else if (element > 501 && element <= 1000) { + assert(partition === 1) + } else if (element > 1001 && element <= 1500) { + assert(partition === 2) + } else if (element > 1501 && element <= 2000) { + assert(partition === 3) + } + }} + } + test("RangePartitioner getPartition") { val rdd = sc.parallelize(1.to(2000)).map(x => (x, x)) // We have different behaviour of getPartition for partitions with less than 1000 and more than |