author     Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-25 21:46:58 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-25 21:46:58 -0700
commit     051785c7e67b7ba0f2f0b5e078753d3f4f380961 (patch)
tree       5ff31cdbae7a7dd61fbf7f0a080771b3ca850d08
parent     56c90485fd947d75bbe7aac81593ba42cfe56821 (diff)
Several fixes to sampling issues pointed out by Henry Milner:
- takeSample was biased towards earlier partitions
- There were some range errors in takeSample
- SampledRDDs with replacement didn't produce appropriate counts across partitions (we took exactly frac of each one)
-rw-r--r--  core/src/main/scala/spark/RDD.scala         | 13
-rw-r--r--  core/src/main/scala/spark/SampledRDD.scala  | 24
-rw-r--r--  core/src/main/scala/spark/Utils.scala       | 26
3 files changed, 35 insertions, 28 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index efe248896a..5fac955286 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -145,8 +145,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var initialCount = count()
var maxSelected = 0
- if (initialCount > Integer.MAX_VALUE) {
- maxSelected = Integer.MAX_VALUE
+ if (initialCount > Integer.MAX_VALUE - 1) {
+ maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
}
@@ -161,15 +161,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
total = num
}
- var samples = this.sample(withReplacement, fraction, seed).collect()
+ val rand = new Random(seed)
+ var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
while (samples.length < total) {
- samples = this.sample(withReplacement, fraction, seed).collect()
+ samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
}
- val arr = samples.take(total)
-
- return arr
+ Utils.randomizeInPlace(samples, rand).take(total)
}
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
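
The takeSample changes above address the bias toward earlier partitions: each retry now draws a fresh seed from a master Random, so repeated attempts are no longer identical, and the collected sample is shuffled before truncation so the first `total` elements no longer over-represent whichever partitions were collected first (the Integer.MAX_VALUE - 1 cap handles the range error). Below is a minimal standalone sketch of that retry-then-shuffle pattern; the drawOnce callback is hypothetical and stands in for this.sample(...).collect(), and scala.util.Random.shuffle stands in for Utils.randomizeInPlace.

import scala.util.Random

object TakeSampleSketch {
  // drawOnce is a hypothetical stand-in for sample(withReplacement, fraction, seed).collect();
  // it is not part of Spark's API.
  def takeSampleLike[T](drawOnce: Int => Array[T], total: Int, seed: Int): Seq[T] = {
    val rand = new Random(seed)
    // Draw a fresh seed for every attempt so retries can return different samples.
    var samples = drawOnce(rand.nextInt())
    while (samples.length < total) {
      samples = drawOnce(rand.nextInt())
    }
    // Shuffle before truncating so the kept elements are not biased
    // toward whichever partitions happened to be collected first.
    rand.shuffle(samples.toSeq).take(total)
  }
}
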
diff --git a/core/src/main/scala/spark/SampledRDD.scala b/core/src/main/scala/spark/SampledRDD.scala
index 8ef40d8d9e..c066017e89 100644
--- a/core/src/main/scala/spark/SampledRDD.scala
+++ b/core/src/main/scala/spark/SampledRDD.scala
@@ -1,6 +1,8 @@
package spark
import java.util.Random
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
override val index: Int = prev.index
@@ -28,19 +30,21 @@ class SampledRDD[T: ClassManifest](
override def compute(splitIn: Split) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
- val rg = new Random(split.seed)
- // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
if (withReplacement) {
- val oldData = prev.iterator(split.prev).toArray
- val sampleSize = (oldData.size * frac).ceil.toInt
- val sampledData = {
- // all of oldData's indices are candidates, even if sampleSize < oldData.size
- for (i <- 1 to sampleSize)
- yield oldData(rg.nextInt(oldData.size))
+ // For large datasets, the expected number of occurrences of each element in a sample with
+ // replacement is Poisson(frac). We use that to get a count for each element.
+ val poisson = new Poisson(frac, new DRand(split.seed))
+ prev.iterator(split.prev).flatMap { element =>
+ val count = poisson.nextInt()
+ if (count == 0) {
+ Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
+ } else {
+ Iterator.fill(count)(element)
+ }
}
- sampledData.iterator
} else { // Sampling without replacement
- prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
+ val rand = new Random(split.seed)
+ prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
}
}
}
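
The with-replacement path above replaces the old materialize-and-index loop with a single streaming pass. The reasoning in the new comment: drawing a size frac*n sample with replacement from n elements gives each element a Binomial(frac*n, 1/n) occurrence count, which is approximately Poisson(frac) for large n, so each element can simply be emitted an independent Poisson(frac) number of times. The patch uses Colt's Poisson and DRand; the sketch below is a self-contained illustration of the same idea using Knuth's inverse-transform Poisson sampler instead, which is adequate for small means such as a sampling fraction.

import scala.util.Random

object PoissonSampleSketch {
  // Knuth's algorithm for drawing a Poisson(lambda) variate; fine for small lambda.
  def poisson(lambda: Double, rand: Random): Int = {
    val limit = math.exp(-lambda)
    var k = 0
    var p = rand.nextDouble()
    while (p > limit) {
      k += 1
      p *= rand.nextDouble()
    }
    k
  }

  // One streaming pass: each element is emitted count ~ Poisson(frac) times,
  // mirroring the flatMap in SampledRDD.compute.
  def sampleWithReplacement[T](data: Iterator[T], frac: Double, seed: Int): Iterator[T] = {
    val rand = new Random(seed)
    data.flatMap { element =>
      val count = poisson(frac, rand)
      if (count == 0) Iterator.empty else Iterator.fill(count)(element)
    }
  }
}

Because the counts are drawn per element, the expected sample size in each partition is still frac times the partition size, but the actual size now varies from partition to partition instead of being forced to exactly ceil(frac * size), which is what the commit message means by producing appropriate counts across partitions.
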
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 5a3f8bde43..eb7d69e816 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -2,12 +2,11 @@ package spark
import java.io._
import java.net.{InetAddress, URL, URI}
-import java.util.{Locale, UUID}
+import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
import scala.io.Source
/**
@@ -172,17 +171,22 @@ object Utils extends Logging {
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
*/
- def randomize[T](seq: TraversableOnce[T]): Seq[T] = {
- val buf = new ArrayBuffer[T]()
- buf ++= seq
- val rand = new Random()
- for (i <- (buf.size - 1) to 1 by -1) {
+ def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+ randomizeInPlace(seq.toArray)
+ }
+
+ /**
+ * Shuffle the elements of an array into a random order, modifying the
+ * original array. Returns the original array.
+ */
+ def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+ for (i <- (arr.length - 1) to 1 by -1) {
val j = rand.nextInt(i)
- val tmp = buf(j)
- buf(j) = buf(i)
- buf(i) = tmp
+ val tmp = arr(j)
+ arr(j) = arr(i)
+ arr(i) = tmp
}
- buf
+ arr
}
/**
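
The Utils change above narrows randomize to a thin wrapper that copies a TraversableOnce into an array (hence the new ClassManifest bound) and defers to the new randomizeInPlace, an in-place shuffle over java.util.Random that accepts a caller-supplied generator, so takeSample can thread its own seeded Random through the shuffle. A small hypothetical usage sketch, assuming spark's Utils is on the classpath:

import java.util.Random
import spark.Utils

object RandomizeInPlaceDemo {
  def main(args: Array[String]): Unit = {
    val arr = Array("a", "b", "c", "d", "e")
    // Passing an explicit seeded Random makes the shuffle reproducible,
    // which is how takeSample keeps its whole run on one seed.
    val shuffled = Utils.randomizeInPlace(arr, new Random(42))
    // randomizeInPlace mutates and returns the same array instance.
    println(shuffled.mkString(", "))
  }
}
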