aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEdison Tung <edisontung@gmail.com>2011-11-21 16:38:44 -0800
committerEdison Tung <edisontung@gmail.com>2011-11-21 16:38:44 -0800
commita3bc012af8a4f7c362a28bc1294942019d5a288d (patch)
treeee670ebe305934ae801ae52ec7a65291ca48a34d
parent3b9d9de583bf2ee0c7b46c75944aedfcfa784a02 (diff)
downloadspark-a3bc012af8a4f7c362a28bc1294942019d5a288d.tar.gz
spark-a3bc012af8a4f7c362a28bc1294942019d5a288d.tar.bz2
spark-a3bc012af8a4f7c362a28bc1294942019d5a288d.zip
added takeSamples method
takeSamples method takes a specified number of samples from the RDD and outputs it in an array.
-rw-r--r--core/src/main/scala/spark/RDD.scala33
1 files changed, 33 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 445d520bc2..7331e40cf7 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -91,6 +91,39 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
new SampledRDD(this, withReplacement, fraction, seed)
+ def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
+ var fraction = 0.0
+ var total = 0
+ var multiplier = 3.0
+
+ if (num > count()) {
+ total = Math.min(count().toInt)
+ fraction = 1.0
+ }
+ else if (num < 0) {
+ throw(new IllegalArgumentException())
+ }
+ else {
+ fraction = Math.min(multiplier*(num+1)/count(), 1.0)
+ total = num.toInt
+ }
+
+ var r = new SampledRDD(this, withReplacement, fraction, seed)
+
+ while (r.count() < total) {
+ r = new SampledRDD(this, withReplacement, fraction, seed)
+ }
+
+ var samples = r.collect()
+ var arr = new Array[T](total)
+
+ for (i <- 0 to total - 1) {
+ arr(i) = samples(i)
+ }
+
+ return arr
+ }
+
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
def ++(other: RDD[T]): RDD[T] = this.union(other)