aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2014-07-18 12:41:50 -0700
committerReynold Xin <rxin@apache.org>2014-07-18 12:41:50 -0700
commit586e716e47305cd7c2c3ff35c0e828b63ef2f6a8 (patch)
tree0ebe507c73e48c12a4f7863b49e177c17c539098 /core/src
parent7f87ab98138d00723e007471f1a7f506650978cb (diff)
downloadspark-586e716e47305cd7c2c3ff35c0e828b63ef2f6a8.tar.gz
spark-586e716e47305cd7c2c3ff35c0e828b63ef2f6a8.tar.bz2
spark-586e716e47305cd7c2c3ff35c0e828b63ef2f6a8.zip
Reservoir sampling implementation.
This is going to be used in https://issues.apache.org/jira/browse/SPARK-2568 Author: Reynold Xin <rxin@apache.org> Closes #1478 from rxin/reservoirSample and squashes the following commits: 17bcbf3 [Reynold Xin] Added seed. badf20d [Reynold Xin] Renamed the method. 6940010 [Reynold Xin] Reservoir sampling implementation.
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala46
-rw-r--r--core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala21
2 files changed, 67 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index a79e3ee756..d10141b90e 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -17,9 +17,55 @@
package org.apache.spark.util.random
+import scala.reflect.ClassTag
+import scala.util.Random
+
private[spark] object SamplingUtils {
/**
+ * Reservoir sampling implementation that also returns the input size.
+ *
+ * @param input input iterator to sample from (consumed entirely by this call)
+ * @param k reservoir size
+ * @param seed random seed
+ * @return (samples, input size)
+ */
+ def reservoirSampleAndCount[T: ClassTag](
+     input: Iterator[T],
+     k: Int,
+     seed: Long = Random.nextLong())
+   : (Array[T], Int) = {
+   // Draws a uniform sample of at most k elements from `input` in a single pass and
+   // returns it together with the number of elements consumed from the iterator.
+   val reservoir = new Array[T](k)
+   // Put the first k elements in the reservoir.
+   var i = 0
+   while (i < k && input.hasNext) {
+     val item = input.next()
+     reservoir(i) = item
+     i += 1
+   }
+
+   // If we have consumed all the elements, return them. Otherwise do the replacement.
+   if (i < k) {
+     // If input size < k, trim the array to return only an array of input size.
+     val trimReservoir = new Array[T](i)
+     System.arraycopy(reservoir, 0, trimReservoir, 0, i)
+     (trimReservoir, i)
+   } else {
+     // If input size >= k, continue the sampling process over the remaining elements.
+     val rand = new XORShiftRandom(seed)
+     while (input.hasNext) {
+       val item = input.next()
+       // Count the current element *before* drawing, so that the i-th element (1-based)
+       // picks a slot uniformly from [0, i) and therefore enters the reservoir with
+       // probability k / i. Drawing from the count of previously seen elements instead
+       // would always accept the (k + 1)-th element and bias the whole sample; it would
+       // also ask for nextInt(0) when k == 0, which java.util.Random-style generators
+       // reject (presumably XORShiftRandom follows that contract -- verify).
+       i += 1
+       val replacementIndex = rand.nextInt(i)
+       if (replacementIndex < k) {
+         reservoir(replacementIndex) = item
+       }
+     }
+     // NOTE: `i` is an Int, so the returned count wraps for inputs beyond Int.MaxValue.
+     (reservoir, i)
+   }
+ }
+
+ /**
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
* the time.
*
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index accfe2e9b7..73a9d029b0 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -17,11 +17,32 @@
package org.apache.spark.util.random
+import scala.util.Random
+
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
import org.scalatest.FunSuite
class SamplingUtilsSuite extends FunSuite {
+ test("reservoirSampleAndCount") {
+   val input = Seq.fill(100)(Random.nextInt())
+
+   // Each scenario samples from a fresh iterator over the same 100 random ints,
+   // using the default (random) seed.
+   def sampleWithReservoir(k: Int): (Array[Int], Int) =
+     SamplingUtils.reservoirSampleAndCount(input.iterator, k)
+
+   // Reservoir larger than the input: every element comes back, in input order.
+   val (largerSample, largerCount) = sampleWithReservoir(150)
+   assert(largerCount === 100)
+   assert(input === largerSample.toSeq)
+
+   // Reservoir exactly as large as the input: again the whole input comes back.
+   val (exactSample, exactCount) = sampleWithReservoir(100)
+   assert(exactCount === 100)
+   assert(input === exactSample.toSeq)
+
+   // Reservoir smaller than the input: the full input is still counted, but only
+   // a reservoir-sized sample is kept.
+   val (smallerSample, smallerCount) = sampleWithReservoir(10)
+   assert(smallerCount === 100)
+   assert(smallerSample.length === 10)
+ }
+
test("computeFraction") {
// test that the computed fraction guarantees enough data points
// in the sample with a failure rate <= 0.0001