// Excerpt of core/src/main/scala/spark/RDD.scala: members of
// `abstract class RDD[T: ClassManifest](@transient sc: SparkContext)`.
// Provenance: commit a3bc012af8a4f7c362a28bc1294942019d5a288d by Edison Tung,
// Mon, 21 Nov 2011 16:38:44 -0800, "added takeSamples method" -- the commit that
// introduced `takeSample` between `sample` and `union`.

/**
 * Wraps this RDD in a `SampledRDD` configured with the given parameters.
 *
 * @param withReplacement forwarded unchanged to `SampledRDD`
 * @param fraction        forwarded unchanged to `SampledRDD`
 * @param seed            forwarded unchanged to `SampledRDD`
 */
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
  new SampledRDD(this, withReplacement, fraction, seed)

/**
 * Collects a fixed-size random sample of this RDD into a local array.
 *
 * The result holds `num` elements, or `count()` elements when the RDD is smaller
 * than `num` (the request is capped at the RDD size even when `withReplacement`
 * is true). The same `seed` always drives the same sequence of sampling attempts.
 *
 * @param withReplacement forwarded to `sample` for every sampling attempt
 * @param num             number of elements requested; must not be negative
 * @param seed            seed of the generator that produces the per-attempt
 *                        sampling seeds and the final shuffle
 * @return an array with `min(num, count())` sampled elements
 * @throws IllegalArgumentException if `num` is negative
 */
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
  require(num >= 0, "Negative number of elements requested: " + num)

  // Over-sampling factor: ask for roughly three times the needed elements so that
  // a single sampling pass is normally enough.
  val multiplier = 3.0

  // count() is evaluated at most once; it is skipped entirely when nothing is requested.
  val initialCount: Long = if (num == 0) 0L else count()

  // The minimum is bounded by `num`, so narrowing back to Int cannot overflow.
  val total = math.min(num.toLong, initialCount).toInt

  if (total == 0) {
    new Array[T](0)
  } else {
    // Computed in Double so `total + 1` cannot overflow; evaluates to 1.0 whenever
    // the whole RDD is requested.
    val fraction = math.min(multiplier * (total + 1.0) / initialCount, 1.0)
    val rand = new scala.util.Random(seed)

    // Each attempt gets a fresh seed: re-sampling with an identical seed would
    // presumably reproduce the same too-small sample and never terminate.
    // NOTE(review): termination still relies on `sample` with fraction 1.0
    // yielding at least `count()` elements -- confirm against SampledRDD.
    val drawn = Iterator
      .continually(sample(withReplacement, fraction, rand.nextInt()).collect())
      .dropWhile(_.length < total)
      .next()

    // Shuffle before truncating so the kept elements are not simply the first
    // `total` entries of the collected sample.
    rand.shuffle(drawn.toList).take(total).toArray
  }
}

/** Builds a `UnionRDD` over this RDD and `other`. */
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))

/** Operator alias for `union`. */
def ++(other: RDD[T]): RDD[T] = this.union(other)