Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala       |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                           | 17
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala  |  4
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala                      |  9
4 files changed, 28 insertions, 6 deletions
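
In summary, the patch threads a new `preservesPartitioning` flag through `PartitionwiseSampledRDD` so that `RDD.sample` and `RDD.randomSplit` retain the parent RDD's partitioner, and it documents the corresponding parameter on the `mapPartitions` family. The observable effect matches the new RDDSuite test below; a minimal sketch of it, assuming an active `SparkContext` named `sc`:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.SparkContext._  // pair-RDD implicits (partitionBy)

    // After this change, sampling a partitioned pair RDD keeps its partitioner.
    val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(new HashPartitioner(2))
    val sampled = rdd.sample(withReplacement = false, fraction = 1.0)
    assert(sampled.partitioner == rdd.partitioner)  // Some(HashPartitioner(2))
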
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index b5b8a5706d..a637d6f15b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -39,6 +39,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
*
* @param prev RDD to be sampled
* @param sampler a random sampler
+ * @param preservesPartitioning whether the sampler preserves the partitioner of the parent RDD
* @param seed random seed
* @tparam T input RDD item type
* @tparam U sampled RDD item type
@@ -46,9 +47,12 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
prev: RDD[T],
sampler: RandomSampler[T, U],
+ @transient preservesPartitioning: Boolean,
@transient seed: Long = Utils.random.nextLong)
extends RDD[U](prev) {

+ @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
override def getPartitions: Array[Partition] = {
val random = new Random(seed)
firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
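
With the added constructor parameter, the sampled RDD either forwards the parent's partitioner or reports `None`. A Spark-internal usage sketch (the class is `private[spark]`, and `pairs` is an assumed pre-partitioned `RDD[(Int, Int)]`):

    import org.apache.spark.util.random.BernoulliSampler

    // preservesPartitioning = true forwards pairs.partitioner to the sampled RDD;
    // with false, sampled.partitioner would be None regardless of the parent.
    val sampled = new PartitionwiseSampledRDD[(Int, Int), (Int, Int)](
      pairs, new BernoulliSampler[(Int, Int)](0.5), preservesPartitioning = true)
    assert(sampled.partitioner == pairs.partitioner)
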
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a1f2827248..c1bafab3e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -356,9 +356,9 @@ abstract class RDD[T: ClassTag](
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
- new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
} else {
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
}
}
@@ -374,7 +374,7 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
}.toArray
}
@@ -586,6 +586,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitions[U: ClassTag](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -596,6 +599,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -607,6 +613,9 @@ abstract class RDD[T: ClassTag](
* :: DeveloperApi ::
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
def mapPartitionsWithContext[U: ClassTag](
@@ -689,7 +698,7 @@ abstract class RDD[T: ClassTag](
* a map on the other).
*/
def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
- zipPartitions(other, true) { (thisIter, otherIter) =>
+ zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) =>
new Iterator[(T, U)] {
def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
case (true, true) => true
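
The new Scaladoc spells out the contract for `preservesPartitioning` on the `mapPartitions` family: pass `true` only for a pair RDD whose keys the closure leaves untouched. A hedged usage sketch of that contract, again assuming an active `SparkContext` named `sc`:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.SparkContext._  // pair-RDD implicits (partitionBy)

    val pairs = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(new HashPartitioner(2))

    // Keys are unchanged, so it is safe to advertise the parent's partitioner.
    val upper = pairs.mapPartitions(
      iter => iter.map { case (k, v) => (k, v.toUpperCase) },
      preservesPartitioning = true)
    assert(upper.partitioner == pairs.partitioner)

    // If the closure rewrote the keys, preservesPartitioning must stay false (the
    // default); otherwise downstream stages could assume a layout that no longer holds.
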
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 5dd8de319a..a0483886f8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -43,7 +43,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
test("seed distribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler
- val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
+ val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L)
assert(sample.distinct().count == 2, "Seeds must be different.")
}
@@ -52,7 +52,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
// We want to make sure there are no concurrency issues.
val rdd = sc.parallelize(0 until 111, 10)
for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
- val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+ val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler, true)
sampled.zip(sampled).count()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2924de1129..6654ec2d7c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -523,6 +523,15 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}
+ test("sample preserves partitioner") {
+ val partitioner = new HashPartitioner(2)
+ val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner)
+ for (withReplacement <- Seq(true, false)) {
+ val sampled = rdd.sample(withReplacement, 1.0)
+ assert(sampled.partitioner === rdd.partitioner)
+ }
+ }
+
test("takeSample") {
val n = 1000000
val data = sc.parallelize(1 to n, 2)