author    Xiangrui Meng <meng@databricks.com>    2014-07-23 00:58:55 -0700
committer Reynold Xin <rxin@apache.org>          2014-07-23 00:58:55 -0700
commit    4c7243e109c713bdfb87891748800109ffbaae07 (patch)
tree      4e280534354b05c0a313336e51b9555e25d8c6d1
parent    6c2be93f081f33e9e97e1231b0084a6a0eb4fa22 (diff)
[SPARK-2617] Correct doc and usages of preservesPartitioning
The name `preservesPartitioning` is ambiguous: it could mean 1) preserving the indices of the partitions, or 2) preserving the partitioner. The latter is the intended meaning, and `preservesPartitioning` should really have been called `preservesPartitioner` to avoid confusion. Unfortunately, it is already part of the API and cannot be renamed. We should be clear in the docs and fix incorrect usages. This PR:

1. adds notes to `mapPartitions*`,
2. makes `RDD.sample` preserve the partitioner,
3. changes `preservesPartitioning` to `false` in `RDD.zip`, because the keys of the first RDD are no longer the keys of the zipped RDD,
4. fixes some incorrect usages in MLlib.

Author: Xiangrui Meng <meng@databricks.com>

Closes #1526 from mengxr/preserve-partitioner and squashes the following commits:

b361e65 [Xiangrui Meng] update doc based on pwendell's comments
3b1ba19 [Xiangrui Meng] update doc
357575c [Xiangrui Meng] fix unit test
20b4816 [Xiangrui Meng] Merge branch 'master' into preserve-partitioner
d1caa65 [Xiangrui Meng] add doc to explain preservesPartitioning; fix wrong usage of preservesPartitioning; make sample preserve partitioning
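For illustration, a minimal sketch of the behavior that items 2 and 3 describe (not part of this patch; it assumes a local SparkContext `sc`):

    import org.apache.spark.HashPartitioner

    val pairs = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(new HashPartitioner(2))

    // 2. sample keeps every record in its partition, so the partitioner carries over.
    assert(pairs.sample(withReplacement = false, fraction = 1.0).partitioner == pairs.partitioner)

    // 3. zip wraps each record in a tuple; the original keys are no longer the
    //    keys of the result, so no partitioner is claimed.
    assert(pairs.zip(pairs).partitioner == None)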
-rw-r--r-- core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala | 4
-rw-r--r-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 17
-rw-r--r-- core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala | 4
-rw-r--r-- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 9
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala | 8
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 4
8 files changed, 37 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index b5b8a5706d..a637d6f15b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -39,6 +39,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
*
* @param prev RDD to be sampled
* @param sampler a random sampler
+ * @param preservesPartitioning whether the sampler preserves the partitioner of the parent RDD
* @param seed random seed
* @tparam T input RDD item type
* @tparam U sampled RDD item type
@@ -46,9 +47,12 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
prev: RDD[T],
sampler: RandomSampler[T, U],
+ @transient preservesPartitioning: Boolean,
@transient seed: Long = Utils.random.nextLong)
extends RDD[U](prev) {
+ @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
override def getPartitions: Array[Partition] = {
val random = new Random(seed)
firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
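The `partitioner` override above is an instance of a general pattern. A hedged standalone sketch (the class name `PassThroughRDD` is hypothetical, not part of Spark):

    import scala.reflect.ClassTag
    import org.apache.spark.{Partition, TaskContext}
    import org.apache.spark.rdd.RDD

    // A one-to-one RDD may forward its parent's partitioner, but only when the
    // caller promises that record placement is unchanged.
    class PassThroughRDD[T: ClassTag](prev: RDD[T], preserve: Boolean)
      extends RDD[T](prev) {

      override val partitioner = if (preserve) prev.partitioner else None

      override def getPartitions: Array[Partition] = prev.partitions

      override def compute(split: Partition, context: TaskContext): Iterator[T] =
        prev.iterator(split, context)
    }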
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a1f2827248..c1bafab3e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -356,9 +356,9 @@ abstract class RDD[T: ClassTag](
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
- new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
} else {
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
}
}
@@ -374,7 +374,7 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
}.toArray
}
@@ -586,6 +586,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitions[U: ClassTag](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
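A hedged usage sketch of this rule (assuming a local SparkContext `sc`): set the flag only when the keys are untouched.

    import org.apache.spark.HashPartitioner

    val pairs = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(new HashPartitioner(2))

    // Values change but keys do not, so the partitioner stays valid.
    val upper = pairs.mapPartitions(
      iter => iter.map { case (k, v) => (k, v.toUpperCase) },
      preservesPartitioning = true)
    assert(upper.partitioner == pairs.partitioner)

    // Keys change, so the flag must stay at its default of false.
    val shifted = pairs.mapPartitions(iter => iter.map { case (k, v) => (k + 1, v) })
    assert(shifted.partitioner == None)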
@@ -596,6 +599,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -607,6 +613,9 @@ abstract class RDD[T: ClassTag](
* :: DeveloperApi ::
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
def mapPartitionsWithContext[U: ClassTag](
@@ -689,7 +698,7 @@ abstract class RDD[T: ClassTag](
* a map on the other).
*/
def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
- zipPartitions(other, true) { (thisIter, otherIter) =>
+ zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) =>
new Iterator[(T, U)] {
def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
case (true, true) => true
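A quick usage sketch for zip (assuming `sc`): both RDDs need the same number of partitions and the same number of elements per partition, which holds when one is a map of the other.

    val a = sc.parallelize(1 to 4, 2)
    val b = a.map(_ * 10) // same partition layout as a
    assert(a.zip(b).collect().toSeq == Seq((1, 10), (2, 20), (3, 30), (4, 40)))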
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 5dd8de319a..a0483886f8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -43,7 +43,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
test("seed distribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler
- val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
+ val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L)
assert(sample.distinct().count == 2, "Seeds must be different.")
}
@@ -52,7 +52,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
// We want to make sure there are no concurrency issues.
val rdd = sc.parallelize(0 until 111, 10)
for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
- val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+ val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler, true)
sampled.zip(sampled).count()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2924de1129..6654ec2d7c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -523,6 +523,15 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}
+ test("sample preserves partitioner") {
+ val partitioner = new HashPartitioner(2)
+ val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner)
+ for (withReplacement <- Seq(true, false)) {
+ val sampled = rdd.sample(withReplacement, 1.0)
+ assert(sampled.partitioner === rdd.partitioner)
+ }
+ }
+
test("takeSample") {
val n = 1000000
val data = sc.parallelize(1 to n, 2)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 079743742d..1af40de2c7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -103,11 +103,11 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends
mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
).sortByKey(ascending = false)
- val agg = counts.values.mapPartitions({ iter =>
+ val agg = counts.values.mapPartitions { iter =>
val agg = new BinaryLabelCounter()
iter.foreach(agg += _)
Iterator(agg)
- }, preservesPartitioning = true).collect()
+ }.collect()
val partitionwiseCumulativeCounts =
agg.scanLeft(new BinaryLabelCounter())(
(agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index f4c403bc78..8c2b044ea7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -377,9 +377,9 @@ class RowMatrix(
s"Only support dense matrix at this time but found ${B.getClass.getName}.")
val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray)
- val AB = rows.mapPartitions({ iter =>
+ val AB = rows.mapPartitions { iter =>
val Bi = Bb.value
- iter.map(row => {
+ iter.map { row =>
val v = BDV.zeros[Double](k)
var i = 0
while (i < k) {
@@ -387,8 +387,8 @@ class RowMatrix(
i += 1
}
Vectors.fromBreeze(v)
- })
- }, preservesPartitioning = true)
+ }
+ }
new RowMatrix(AB, nRows, B.numCols)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 15e8855db6..5356790cb5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -430,7 +430,7 @@ class ALS private (
val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner)
val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner)
Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
- }, true)
+ }, preservesPartitioning = true)
val inLinks = links.mapValues(_._1)
val outLinks = links.mapValues(_._2)
inLinks.persist(StorageLevel.MEMORY_AND_DISK)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index aaf92a1a88..30de24ad89 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -264,8 +264,8 @@ object MLUtils {
(1 to numFolds).map { fold =>
val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
complement = false)
- val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
- val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
+ val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
+ val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
(training, validation)
}.toArray
}
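Finally, a minimal sketch of how these complementary samplers are consumed through `MLUtils.kFold` (assuming a local SparkContext `sc`): with a shared seed, each fold's sampler and its cloned complement split the input disjointly, so every training/validation pair covers the whole RDD.

    import org.apache.spark.mllib.util.MLUtils

    val data = sc.parallelize(1 to 100, 4)
    val folds = MLUtils.kFold(data, numFolds = 3, seed = 42)
    folds.foreach { case (training, validation) =>
      // Disjoint and complete: the two sides exactly partition the input.
      assert(training.count() + validation.count() == data.count())
    }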