aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-07-23 00:58:55 -0700
committerReynold Xin <rxin@apache.org>2014-07-23 00:58:55 -0700
commit4c7243e109c713bdfb87891748800109ffbaae07 (patch)
tree4e280534354b05c0a313336e51b9555e25d8c6d1 /core/src/main/scala
parent6c2be93f081f33e9e97e1231b0084a6a0eb4fa22 (diff)
downloadspark-4c7243e109c713bdfb87891748800109ffbaae07.tar.gz
spark-4c7243e109c713bdfb87891748800109ffbaae07.tar.bz2
spark-4c7243e109c713bdfb87891748800109ffbaae07.zip
[SPARK-2617] Correct doc and usages of preservesPartitioning
The name `preservesPartitioning` is ambiguous: 1) preserves the indices of partitions, 2) preserves the partitioner. The latter is correct and `preservesPartitioning` should really be called `preservesPartitioner` to avoid confusion. Unfortunately, this is already part of the API and we cannot change. We should be clear in the doc and fix wrong usages. This PR 1. adds notes in `maPartitions*`, 2. makes `RDD.sample` preserve partitioner, 3. changes `preservesPartitioning` to false in `RDD.zip` because the keys of the first RDD are no longer the keys of the zipped RDD, 4. fixes some wrong usages in MLlib. Author: Xiangrui Meng <meng@databricks.com> Closes #1526 from mengxr/preserve-partitioner and squashes the following commits: b361e65 [Xiangrui Meng] update doc based on pwendell's comments 3b1ba19 [Xiangrui Meng] update doc 357575c [Xiangrui Meng] fix unit test 20b4816 [Xiangrui Meng] Merge branch 'master' into preserve-partitioner d1caa65 [Xiangrui Meng] add doc to explain preservesPartitioning fix wrong usage of preservesPartitioning make sample preserse partitioning
Diffstat (limited to 'core/src/main/scala')
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala17
2 files changed, 17 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index b5b8a5706d..a637d6f15b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -39,6 +39,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
*
* @param prev RDD to be sampled
* @param sampler a random sampler
+ * @param preservesPartitioning whether the sampler preserves the partitioner of the parent RDD
* @param seed random seed
* @tparam T input RDD item type
* @tparam U sampled RDD item type
@@ -46,9 +47,12 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
prev: RDD[T],
sampler: RandomSampler[T, U],
+ @transient preservesPartitioning: Boolean,
@transient seed: Long = Utils.random.nextLong)
extends RDD[U](prev) {
+ @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
override def getPartitions: Array[Partition] = {
val random = new Random(seed)
firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a1f2827248..c1bafab3e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -356,9 +356,9 @@ abstract class RDD[T: ClassTag](
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
- new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
} else {
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
}
}
@@ -374,7 +374,7 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
}.toArray
}
@@ -586,6 +586,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitions[U: ClassTag](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -596,6 +599,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -607,6 +613,9 @@ abstract class RDD[T: ClassTag](
* :: DeveloperApi ::
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
def mapPartitionsWithContext[U: ClassTag](
@@ -689,7 +698,7 @@ abstract class RDD[T: ClassTag](
* a map on the other).
*/
def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
- zipPartitions(other, true) { (thisIter, otherIter) =>
+ zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) =>
new Iterator[(T, U)] {
def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
case (true, true) => true