diff options
author | Justin Ma <jtma@eecs.berkeley.edu> | 2010-08-18 15:59:35 -0700 |
---|---|---|
committer | Justin Ma <jtma@eecs.berkeley.edu> | 2010-08-18 15:59:35 -0700 |
commit | ea8c2785ddd0439cb0433dde08d2a7caecfc06cb (patch) | |
tree | 93468dd20520df6b71a500ba7694d65bd3d37b39 /src | |
parent | 156bccbe230d36af93fe46f8942c3ac1d18635ae (diff) | |
download | spark-ea8c2785ddd0439cb0433dde08d2a7caecfc06cb.tar.gz spark-ea8c2785ddd0439cb0433dde08d2a7caecfc06cb.tar.bz2 spark-ea8c2785ddd0439cb0433dde08d2a7caecfc06cb.zip |
now we have sampling with replacement (at least on a per-split basis)
Diffstat (limited to 'src')
-rw-r--r-- | src/scala/spark/RDD.scala | 20 |
1 file changed, 17 insertions, 3 deletions
diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala index ba4bff3c4b..8dd7b104fb 100644 --- a/src/scala/spark/RDD.scala +++ b/src/scala/spark/RDD.scala @@ -27,7 +27,7 @@ abstract class RDD[T: ClassManifest, Split]( def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f)) def aggregateSplit() = new SplitRDD(this) def cache() = new CachedRDD(this) - def sample(frac: Double, seed: Int) = new SampledRDD(this, frac, seed) + def sample(withReplacement: Boolean, frac: Double, seed: Int) = new SampledRDD(this, withReplacement, frac, seed) def foreach(f: T => Unit) { val cleanF = sc.clean(f) @@ -153,15 +153,29 @@ extends RDD[Array[T], Split](prev.sparkContext) { @serializable class SeededSplit[Split](val prev: Split, val seed: Int) {} class SampledRDD[T: ClassManifest, Split]( - prev: RDD[T, Split], frac: Double, seed: Int) + prev: RDD[T, Split], withReplacement: Boolean, frac: Double, seed: Int) extends RDD[T, SeededSplit[Split]](prev.sparkContext) { @transient val splits_ = { val rg = new Random(seed); prev.splits.map(x => new SeededSplit(x, rg.nextInt)) } override def splits = splits_ override def preferredLocations(split: SeededSplit[Split]) = prev.preferredLocations(split.prev) - override def iterator(split: SeededSplit[Split]) = { val rg = new Random(split.seed); prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac)) } + override def iterator(split: SeededSplit[Split]) = { + val rg = new Random(split.seed); + // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?) 
+ if (withReplacement) { + val oldData = prev.iterator(split.prev).toArray + val sampleSize = (oldData.size * frac).ceil.toInt + val sampledData = for (i <- 1 to sampleSize) yield oldData(rg.nextInt(oldData.size)) // all of oldData's indices are candidates, even if sampleSize < oldData.size + sampledData.iterator + } + // Sampling without replacement + else { + prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac)) + } + } override def taskStarted(split: SeededSplit[Split], slot: SlaveOffer) = prev.taskStarted(split.prev, slot) } + class CachedRDD[T, Split]( prev: RDD[T, Split])(implicit m: ClassManifest[T]) extends RDD[T, Split](prev.sparkContext) { |