-rw-r--r--  core/src/main/scala/org/apache/spark/Partitioner.scala       | 121
-rw-r--r--  core/src/test/scala/org/apache/spark/PartitioningSuite.scala |  64
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala      |   5
3 files changed, 171 insertions(+), 19 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 52c018baa5..37053bb6f3 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,11 +19,15 @@ package org.apache.spark
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
-import scala.reflect.ClassTag
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.{ClassTag, classTag}
+import scala.util.hashing.byteswap32
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{CollectionsUtils, Utils}
+import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}
/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -103,26 +107,49 @@ class RangePartitioner[K : Ordering : ClassTag, V](
    private var ascending: Boolean = true)
  extends Partitioner {

+  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
+  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
+
  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
-    if (partitions == 1) {
-      Array()
+    if (partitions <= 1) {
+      Array.empty
    } else {
-      val rddSize = rdd.count()
-      val maxSampleSize = partitions * 20.0
-      val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
-      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
-      if (rddSample.length == 0) {
-        Array()
+      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
+      val sampleSize = math.min(20.0 * partitions, 1e6)
+      // Assume the input partitions are roughly balanced and over-sample a little bit.
+      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
+      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
+      if (numItems == 0L) {
+        Array.empty
      } else {
-        val bounds = new Array[K](partitions - 1)
-        for (i <- 0 until partitions - 1) {
-          val index = (rddSample.length - 1) * (i + 1) / partitions
-          bounds(i) = rddSample(index)
+        // If a partition contains much more than the average number of items, we re-sample from it
+        // to ensure that enough items are collected from that partition.
+        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
+        val candidates = ArrayBuffer.empty[(K, Float)]
+        val imbalancedPartitions = mutable.Set.empty[Int]
+        sketched.foreach { case (idx, n, sample) =>
+          if (fraction * n > sampleSizePerPartition) {
+            imbalancedPartitions += idx
+          } else {
+            // The weight is 1 over the sampling probability.
+            val weight = (n.toDouble / sample.size).toFloat
+            for (key <- sample) {
+              candidates += ((key, weight))
+            }
+          }
+        }
+        if (imbalancedPartitions.nonEmpty) {
+          // Re-sample imbalanced partitions with the desired sampling probability.
+          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
+          val seed = byteswap32(-rdd.id - 1)
+          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
+          val weight = (1.0 / fraction).toFloat
+          candidates ++= reSampled.map(x => (x, weight))
        }
-        bounds
+        RangePartitioner.determineBounds(candidates, partitions)
      }
    }
  }
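
To put numbers on the sampling logic in rangeBounds above, here is a small worked example; the figures (100 target partitions, 50 input partitions, ten million keys) are illustrative assumptions, not values taken from this patch:

    // Illustrative numbers only: 100 output partitions, 50 input partitions, 1e7 keys.
    val targetPartitions = 100
    val inputPartitions = 50
    val numItems = 10000000L

    val sampleSize = math.min(20.0 * targetPartitions, 1e6)                           // 2000.0
    val sampleSizePerPartition = math.ceil(3.0 * sampleSize / inputPartitions).toInt  // 120
    val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)                 // 2.0e-4
    // A partition is re-sampled when fraction * n > sampleSizePerPartition,
    // i.e. here when it holds more than 120 / 2.0e-4 = 600,000 keys.
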
@@ -212,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V](
    }
  }
}
+
+private[spark] object RangePartitioner {
+
+  /**
+   * Sketches the input RDD via reservoir sampling on each partition.
+   *
+   * @param rdd the input RDD to sketch
+   * @param sampleSizePerPartition max sample size per partition
+   * @return (total number of items, an array of (partitionId, number of items, sample))
+   */
+  def sketch[K : ClassTag](
+      rdd: RDD[K],
+      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
+    val shift = rdd.id
+    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
+    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
+      val seed = byteswap32(idx ^ (shift << 16))
+      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
+        iter, sampleSizePerPartition, seed)
+      Iterator((idx, n, sample))
+    }.collect()
+    val numItems = sketched.map(_._2.toLong).sum
+    (numItems, sketched)
+  }
+
+  /**
+   * Determines the bounds for range partitioning from candidates with weights indicating how many
+   * items each represents. Usually this is 1 over the probability used to sample this candidate.
+   *
+   * @param candidates unordered candidates with weights
+   * @param partitions number of partitions
+   * @return selected bounds
+   */
+  def determineBounds[K : Ordering : ClassTag](
+      candidates: ArrayBuffer[(K, Float)],
+      partitions: Int): Array[K] = {
+    val ordering = implicitly[Ordering[K]]
+    val ordered = candidates.sortBy(_._1)
+    val numCandidates = ordered.size
+    val sumWeights = ordered.map(_._2.toDouble).sum
+    val step = sumWeights / partitions
+    var cumWeight = 0.0
+    var target = step
+    val bounds = ArrayBuffer.empty[K]
+    var i = 0
+    var j = 0
+    var previousBound = Option.empty[K]
+    while ((i < numCandidates) && (j < partitions - 1)) {
+      val (key, weight) = ordered(i)
+      cumWeight += weight
+      if (cumWeight > target) {
+        // Skip duplicate values.
+        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
+          bounds += key
+          target += step
+          j += 1
+          previousBound = Some(key)
+        }
+      }
+      i += 1
+    }
+    bounds.toArray
+  }
+}
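
The sketch method above relies on SamplingUtils.reservoirSampleAndCount to draw a fixed-size uniform sample from each partition in a single pass. The following standalone sketch shows the idea behind that helper (classic reservoir sampling plus an item count); it is a simplified illustration under assumed behaviour, not the actual org.apache.spark.util.random implementation:

    import scala.reflect.ClassTag
    import scala.util.Random

    // Keep a uniform random sample of at most k items from an iterator of unknown
    // length, and also return how many items were seen in total.
    def reservoirSampleAndCountSketch[T: ClassTag](
        input: Iterator[T], k: Int, seed: Long): (Array[T], Int) = {
      val rng = new Random(seed)
      val reservoir = new Array[T](k)
      var n = 0
      while (input.hasNext) {
        val item = input.next()
        if (n < k) {
          reservoir(n) = item            // fill the reservoir with the first k items
        } else {
          val j = rng.nextInt(n + 1)     // keep the new item with probability k / (n + 1)
          if (j < k) reservoir(j) = item
        }
        n += 1
      }
      // If the stream had fewer than k items, return only what was actually seen.
      (if (n < k) reservoir.take(n) else reservoir, n)
    }

    val (sample, count) = reservoirSampleAndCountSketch(Iterator.range(0, 1000), 10, 42L)

Pairing each partition's sample with its item count is what lets determineBounds weight every candidate by 1 over its sampling probability.
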
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 4658a08064..fc0cee3e87 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import scala.collection.mutable.ArrayBuffer
import scala.math.abs
import org.scalatest.{FunSuite, PrivateMethodTester}
@@ -52,14 +53,12 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
    assert(p2 === p2)
    assert(p4 === p4)
-    assert(p2 != p4)
-    assert(p4 != p2)
+    assert(p2 === p4)
    assert(p4 === anotherP4)
    assert(anotherP4 === p4)
    assert(descendingP2 === descendingP2)
    assert(descendingP4 === descendingP4)
-    assert(descendingP2 != descendingP4)
-    assert(descendingP4 != descendingP2)
+    assert(descendingP2 === descendingP4)
    assert(p2 != descendingP2)
    assert(p4 != descendingP4)
    assert(descendingP2 != p2)
@@ -102,6 +101,63 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
    partitioner.getPartition(Row(100))
  }

+  test("RangePartitioner.sketch") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(i)(random.nextDouble())
+    }.cache()
+    val sampleSizePerPartition = 10
+    val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition)
+    assert(count === rdd.count())
+    sketched.foreach { case (idx, n, sample) =>
+      assert(n === idx)
+      assert(sample.size === math.min(n, sampleSizePerPartition))
+    }
+  }
+
+  test("RangePartitioner.determineBounds") {
+    assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty,
+      "Bounds on an empty candidates set should be empty.")
+    val candidates = ArrayBuffer(
+      (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
+    assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7))
+  }
+
+  test("RangePartitioner should run only one job if data is roughly balanced") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(5000 * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(10, 20, 40)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 3.0 * counts.min)
+    }
+  }
+
+  test("RangePartitioner should work well on unbalanced data") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(2, 4, 8)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 3.0 * counts.min)
+    }
+  }
+
+  test("RangePartitioner should return a single partition for empty RDDs") {
+    val empty1 = sc.emptyRDD[(Int, Double)]
+    val partitioner1 = new RangePartitioner(0, empty1)
+    assert(partitioner1.numPartitions === 1)
+    val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)])
+    val partitioner2 = new RangePartitioner(2, empty2)
+    assert(partitioner2.numPartitions === 1)
+  }
+
  test("HashPartitioner not equal to RangePartitioner") {
    val rdd = sc.parallelize(1 to 10).map(x => (x, x))
    val rangeP2 = new RangePartitioner(2, rdd)
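
The balance assertions above map onto ordinary use of the partitioner. Below is a minimal driver sketch, assuming an existing SparkContext named sc (as the suite's SharedSparkContext provides), that builds a skewed pair RDD and counts how many keys a RangePartitioner would route to each partition:

    import org.apache.spark.RangePartitioner
    import org.apache.spark.SparkContext._   // pair-RDD implicits such as `keys`

    // Assumes `sc` is an already-constructed SparkContext (e.g. from spark-shell).
    val pairs = sc.parallelize(0 until 20, 20).flatMap { i =>
      val random = new java.util.Random(i)
      Iterator.fill(1000 * i)((random.nextDouble() + i, i))   // skewed: input partition i holds 1000 * i keys
    }

    val partitioner = new RangePartitioner(8, pairs)
    // Map each key to the partition it would be sent to and count the keys per partition.
    val sizes = pairs.keys.map(key => partitioner.getPartition(key)).countByValue()
    sizes.toSeq.sortBy(_._1).foreach { case (partition, count) =>
      println(s"partition $partition -> $count keys")
    }

On roughly balanced input, constructing the partitioner should trigger only the single sampling job referred to in the test name above.
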
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6654ec2d7c..fdc83bc0a5 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -613,6 +613,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
    }
  }

+  test("sort an empty RDD") {
+    val data = sc.emptyRDD[Int]
+    assert(data.sortBy(x => x).collect() === Array.empty)
+  }
+
  test("sortByKey") {
    val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))