author    Justin Ma <jtma@eecs.berkeley.edu>  2010-08-18 15:59:35 -0700
committer Justin Ma <jtma@eecs.berkeley.edu>  2010-08-18 15:59:35 -0700
commit    ea8c2785ddd0439cb0433dde08d2a7caecfc06cb (patch)
tree      93468dd20520df6b71a500ba7694d65bd3d37b39 /src
parent    156bccbe230d36af93fe46f8942c3ac1d18635ae (diff)
now we have sampling with replacement (at least on a per-split basis)
Diffstat (limited to 'src')
 src/scala/spark/RDD.scala | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)
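
For orientation: the two modes this patch distinguishes differ in how the per-split sample size behaves. With replacement, the patch draws a fixed ceil(frac * n) elements from a split of n elements (duplicates possible); without replacement, each element is kept independently with probability frac, so the sample size is Binomial(n, frac) and only approximately frac * n. A minimal illustration (the values of n and frac are hypothetical, not from the commit):

// Illustration only: per-split sample-size behavior of the two modes.
val n = 1000                                  // hypothetical split size
val frac = 0.1
val sizeWithRepl = math.ceil(n * frac).toInt  // always 100 draws; duplicates allowed
val meanSizeNoRepl = n * frac                 // Binomial(n, frac) mean; actual size varies per run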
diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index ba4bff3c4b..8dd7b104fb 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -27,7 +27,7 @@ abstract class RDD[T: ClassManifest, Split](
   def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f))
   def aggregateSplit() = new SplitRDD(this)
   def cache() = new CachedRDD(this)
-  def sample(frac: Double, seed: Int) = new SampledRDD(this, frac, seed)
+  def sample(withReplacement: Boolean, frac: Double, seed: Int) = new SampledRDD(this, withReplacement, frac, seed)
 
   def foreach(f: T => Unit) {
     val cleanF = sc.clean(f)
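
A hedged usage sketch of the new signature (the rdd value and its element type are hypothetical; only the sample signature itself comes from this patch):

// Hypothetical caller, assuming some rdd: RDD[Int, Split] already exists.
val bootstrap = rdd.sample(true, 0.5, 42)   // with replacement: duplicates possible
val subset    = rdd.sample(false, 0.5, 42)  // without replacement: each element kept with prob. 0.5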
@@ -153,15 +153,29 @@ extends RDD[Array[T], Split](prev.sparkContext) {
 @serializable class SeededSplit[Split](val prev: Split, val seed: Int) {}
 class SampledRDD[T: ClassManifest, Split](
-  prev: RDD[T, Split], frac: Double, seed: Int)
+  prev: RDD[T, Split], withReplacement: Boolean, frac: Double, seed: Int)
 extends RDD[T, SeededSplit[Split]](prev.sparkContext) {
   @transient val splits_ = { val rg = new Random(seed); prev.splits.map(x => new SeededSplit(x, rg.nextInt)) }
   override def splits = splits_
   override def preferredLocations(split: SeededSplit[Split]) = prev.preferredLocations(split.prev)
-  override def iterator(split: SeededSplit[Split]) = { val rg = new Random(split.seed); prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac)) }
+  override def iterator(split: SeededSplit[Split]) = {
+    val rg = new Random(split.seed);
+    // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
+    if (withReplacement) {
+      val oldData = prev.iterator(split.prev).toArray
+      val sampleSize = (oldData.size * frac).ceil.toInt
+      val sampledData = for (i <- 1 to sampleSize) yield oldData(rg.nextInt(oldData.size)) // all of oldData's indices are candidates, even if sampleSize < oldData.size
+      sampledData.iterator
+    }
+    // Sampling without replacement
+    else {
+      prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
+    }
+  }
   override def taskStarted(split: SeededSplit[Split], slot: SlaveOffer) = prev.taskStarted(split.prev, slot)
 }
+
 class CachedRDD[T, Split](
   prev: RDD[T, Split])(implicit m: ClassManifest[T])
 extends RDD[T, Split](prev.sparkContext) {
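
The TODO in the new iterator asks whether reservoir sampling could avoid materializing the whole split via prev.iterator(split.prev).toArray. One streaming shape that could take, offered only as a sketch under assumptions (the helper name and Seq result are mine, not the commit's): keep k one-element reservoirs and scan the split once, letting the i-th element overwrite each slot independently with probability 1/i, which keeps every slot uniform over the elements seen so far.

import scala.util.Random
import scala.collection.mutable.ArrayBuffer

// Sketch: single-pass sampling with replacement, per the patch's TODO.
// After the i-th element, each of the k slots independently holds a
// uniform draw from the first i elements, so the final contents are
// k i.i.d. uniform draws (with replacement) from the split, computed
// without buffering the whole split in memory.
def sampleWithReplacement[T](data: Iterator[T], k: Int, rg: Random): Seq[T] = {
  val reservoir = new ArrayBuffer[T](k)
  var i = 0L
  while (data.hasNext) {
    val x = data.next()
    i += 1
    if (i == 1) {
      var j = 0
      while (j < k) { reservoir += x; j += 1 }  // seed every slot with the first element
    } else {
      var j = 0
      while (j < k) {
        if (rg.nextDouble() < 1.0 / i) reservoir(j) = x  // replace slot j with probability 1/i
        j += 1
      }
    }
  }
  reservoir.toSeq
}

One wrinkle the sketch exposes: the patch sizes the sample as ceil(frac * oldData.size), and that count is only known once the split has been fully scanned, so a streaming version would have to either fix k up front or grow the set of reservoirs as elements arrive.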