diff options
author | Reynold Xin <rxin@apache.org> | 2014-07-18 12:41:50 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-07-18 12:41:50 -0700 |
commit | 586e716e47305cd7c2c3ff35c0e828b63ef2f6a8 (patch) | |
tree | 0ebe507c73e48c12a4f7863b49e177c17c539098 /core | |
parent | 7f87ab98138d00723e007471f1a7f506650978cb (diff) | |
download | spark-586e716e47305cd7c2c3ff35c0e828b63ef2f6a8.tar.gz spark-586e716e47305cd7c2c3ff35c0e828b63ef2f6a8.tar.bz2 spark-586e716e47305cd7c2c3ff35c0e828b63ef2f6a8.zip |
Reservoir sampling implementation.
This is going to be used in https://issues.apache.org/jira/browse/SPARK-2568
Author: Reynold Xin <rxin@apache.org>
Closes #1478 from rxin/reservoirSample and squashes the following commits:
17bcbf3 [Reynold Xin] Added seed.
badf20d [Reynold Xin] Renamed the method.
6940010 [Reynold Xin] Reservoir sampling implementation.
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala | 46 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala | 21 |
2 files changed, 67 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index a79e3ee756..d10141b90e 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -17,9 +17,55 @@ package org.apache.spark.util.random +import scala.reflect.ClassTag +import scala.util.Random + private[spark] object SamplingUtils { /** + * Reservoir sampling implementation that also returns the input size. + * + * @param input input size + * @param k reservoir size + * @param seed random seed + * @return (samples, input size) + */ + def reservoirSampleAndCount[T: ClassTag]( + input: Iterator[T], + k: Int, + seed: Long = Random.nextLong()) + : (Array[T], Int) = { + val reservoir = new Array[T](k) + // Put the first k elements in the reservoir. + var i = 0 + while (i < k && input.hasNext) { + val item = input.next() + reservoir(i) = item + i += 1 + } + + // If we have consumed all the elements, return them. Otherwise do the replacement. + if (i < k) { + // If input size < k, trim the array to return only an array of input size. + val trimReservoir = new Array[T](i) + System.arraycopy(reservoir, 0, trimReservoir, 0, i) + (trimReservoir, i) + } else { + // If input size > k, continue the sampling process. + val rand = new XORShiftRandom(seed) + while (input.hasNext) { + val item = input.next() + val replacementIndex = rand.nextInt(i) + if (replacementIndex < k) { + reservoir(replacementIndex) = item + } + i += 1 + } + (reservoir, i) + } + } + + /** * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of * the time. * diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index accfe2e9b7..73a9d029b0 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -17,11 +17,32 @@ package org.apache.spark.util.random +import scala.util.Random + import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.scalatest.FunSuite class SamplingUtilsSuite extends FunSuite { + test("reservoirSampleAndCount") { + val input = Seq.fill(100)(Random.nextInt()) + + // input size < k + val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150) + assert(count1 === 100) + assert(input === sample1.toSeq) + + // input size == k + val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100) + assert(count2 === 100) + assert(input === sample2.toSeq) + + // input size > k + val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10) + assert(count3 === 100) + assert(sample3.length === 10) + } + test("computeFraction") { // test that the computed fraction guarantees enough data points // in the sample with a failure rate <= 0.0001 |