aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala')
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala17
2 files changed, 17 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index b5b8a5706d..a637d6f15b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -39,6 +39,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
*
* @param prev RDD to be sampled
* @param sampler a random sampler
+ * @param preservesPartitioning whether the sampler preserves the partitioner of the parent RDD
* @param seed random seed
* @tparam T input RDD item type
* @tparam U sampled RDD item type
@@ -46,9 +47,12 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
prev: RDD[T],
sampler: RandomSampler[T, U],
+ @transient preservesPartitioning: Boolean,
@transient seed: Long = Utils.random.nextLong)
extends RDD[U](prev) {
+ @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
override def getPartitions: Array[Partition] = {
val random = new Random(seed)
firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a1f2827248..c1bafab3e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -356,9 +356,9 @@ abstract class RDD[T: ClassTag](
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
- new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
} else {
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
}
}
@@ -374,7 +374,7 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
}.toArray
}
@@ -586,6 +586,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitions[U: ClassTag](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -596,6 +599,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -607,6 +613,9 @@ abstract class RDD[T: ClassTag](
* :: DeveloperApi ::
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
def mapPartitionsWithContext[U: ClassTag](
@@ -689,7 +698,7 @@ abstract class RDD[T: ClassTag](
* a map on the other).
*/
def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
- zipPartitions(other, true) { (thisIter, otherIter) =>
+ zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) =>
new Iterator[(T, U)] {
def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
case (true, true) => true