From 051785c7e67b7ba0f2f0b5e078753d3f4f380961 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Tue, 25 Sep 2012 21:46:58 -0700
Subject: Several fixes to sampling issues pointed out by Henry Milner:

- takeSample was biased towards earlier partitions
- There were some range errors in takeSample
- SampledRDDs with replacement didn't produce appropriate counts
  across partitions (we took exactly frac of each one)
---
 core/src/main/scala/spark/RDD.scala        | 13 ++++++-------
 core/src/main/scala/spark/SampledRDD.scala | 24 ++++++++++++++----------
 core/src/main/scala/spark/Utils.scala      | 26 +++++++++++++++-----------
 3 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index efe248896a..5fac955286 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -145,8 +145,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var initialCount = count()
     var maxSelected = 0
 
-    if (initialCount > Integer.MAX_VALUE) {
-      maxSelected = Integer.MAX_VALUE
+    if (initialCount > Integer.MAX_VALUE - 1) {
+      maxSelected = Integer.MAX_VALUE - 1
     } else {
       maxSelected = initialCount.toInt
     }
@@ -161,15 +161,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       total = num
     }
 
-    var samples = this.sample(withReplacement, fraction, seed).collect()
+    val rand = new Random(seed)
+    var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
 
     while (samples.length < total) {
-      samples = this.sample(withReplacement, fraction, seed).collect()
+      samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
     }
 
-    val arr = samples.take(total)
-
-    return arr
+    Utils.randomizeInPlace(samples, rand).take(total)
   }
 
   def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
diff --git a/core/src/main/scala/spark/SampledRDD.scala b/core/src/main/scala/spark/SampledRDD.scala
index 8ef40d8d9e..c066017e89 100644
--- a/core/src/main/scala/spark/SampledRDD.scala
+++ b/core/src/main/scala/spark/SampledRDD.scala
@@ -1,6 +1,8 @@
 package spark
 
 import java.util.Random
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
 
 class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
   override val index: Int = prev.index
@@ -28,19 +30,21 @@ class SampledRDD[T: ClassManifest](
 
   override def compute(splitIn: Split) = {
     val split = splitIn.asInstanceOf[SampledRDDSplit]
-    val rg = new Random(split.seed)
-    // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
     if (withReplacement) {
-      val oldData = prev.iterator(split.prev).toArray
-      val sampleSize = (oldData.size * frac).ceil.toInt
-      val sampledData = {
-        // all of oldData's indices are candidates, even if sampleSize < oldData.size
-        for (i <- 1 to sampleSize)
-          yield oldData(rg.nextInt(oldData.size))
+      // For large datasets, the expected number of occurrences of each element in a sample with
+      // replacement is Poisson(frac). We use that to get a count for each element.
+      val poisson = new Poisson(frac, new DRand(split.seed))
+      prev.iterator(split.prev).flatMap { element =>
+        val count = poisson.nextInt()
+        if (count == 0) {
+          Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
+        } else {
+          Iterator.fill(count)(element)
+        }
       }
-      sampledData.iterator
     } else { // Sampling without replacement
-      prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
+      val rand = new Random(split.seed)
+      prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
     }
   }
 }
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 5a3f8bde43..eb7d69e816 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -2,12 +2,11 @@ package spark
 
 import java.io._
 import java.net.{InetAddress, URL, URI}
-import java.util.{Locale, UUID}
+import java.util.{Locale, Random, UUID}
 import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
 import scala.io.Source
 
 /**
@@ -172,17 +171,22 @@ object Utils extends Logging {
   * result in a new collection. Unlike scala.util.Random.shuffle, this method
   * uses a local random number generator, avoiding inter-thread contention.
   */
-  def randomize[T](seq: TraversableOnce[T]): Seq[T] = {
-    val buf = new ArrayBuffer[T]()
-    buf ++= seq
-    val rand = new Random()
-    for (i <- (buf.size - 1) to 1 by -1) {
+  def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+    randomizeInPlace(seq.toArray)
+  }
+
+  /**
+   * Shuffle the elements of an array into a random order, modifying the
+   * original array. Returns the original array.
+   */
+  def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+    for (i <- (arr.length - 1) to 1 by -1) {
       val j = rand.nextInt(i)
-      val tmp = buf(j)
-      buf(j) = buf(i)
-      buf(i) = tmp
+      val tmp = arr(j)
+      arr(j) = arr(i)
+      arr(i) = tmp
     }
-    buf
+    arr
   }
 
   /**
-- 
cgit v1.2.3
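
A minimal, self-contained Scala sketch of the two techniques the diff relies on: per-element Poisson counts for sampling with replacement, and an in-place Fisher-Yates shuffle. Everything here is illustrative rather than Spark code: the object and method names are hypothetical, and Knuth's product method stands in for the Colt Poisson generator (cern.jet.random) that the patch itself uses.

    import java.util.Random

    // Illustrative sketch only -- not part of the patch above.
    object SamplingSketch {
      // Draw a Poisson(mean)-distributed count via Knuth's product method;
      // adequate for the small means (sampling fractions) involved here.
      def poissonCount(mean: Double, rand: Random): Int = {
        val limit = math.exp(-mean)
        var count = -1
        var product = 1.0
        do {
          count += 1
          product *= rand.nextDouble()
        } while (product > limit)
        count
      }

      // Sampling with replacement in one pass, without buffering the data:
      // each element's multiplicity in the sample is drawn as Poisson(frac).
      def sampleWithReplacement[T](data: Iterator[T], frac: Double, seed: Int): Iterator[T] = {
        val rand = new Random(seed)
        data.flatMap { element =>
          val count = poissonCount(frac, rand)
          if (count == 0) Iterator.empty else Iterator.fill(count)(element)
        }
      }

      // Standard in-place Fisher-Yates shuffle; j ranges over 0..i inclusive,
      // hence nextInt(i + 1).
      def shuffleInPlace[T](arr: Array[T], rand: Random): Array[T] = {
        for (i <- (arr.length - 1) to 1 by -1) {
          val j = rand.nextInt(i + 1)
          val tmp = arr(j)
          arr(j) = arr(i)
          arr(i) = tmp
        }
        arr
      }

      def main(args: Array[String]) {
        println(sampleWithReplacement(Iterator.range(1, 21), 0.5, 42).toList)
        println(shuffleInPlace(Array(1, 2, 3, 4, 5), new Random(42)).mkString(", "))
      }
    }

The design point behind the Poisson approach: taking exactly frac of every partition makes each partition's contribution deterministic, whereas in a true with-replacement sample the per-partition counts fluctuate. Drawing each element's multiplicity from Poisson(frac) reproduces that fluctuation in a single streaming pass.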