diff options
author | Doris Xin <doris.s.xin@gmail.com> | 2014-06-12 19:44:27 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-06-12 19:44:27 -0700 |
commit | 1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47 (patch) | |
tree | f99459c7412db3dd9479037c41e5a4055853ae09 /core | |
parent | 0154587ab71d1b864f97497dbb38bc52b87675be (diff) | |
download | spark-1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47.tar.gz spark-1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47.tar.bz2 spark-1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47.zip |
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes sampling rate > sample_size/total to ensure sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate choice of sampling rate.
Author: Doris Xin <doris.s.xin@gmail.com>
Author: dorx <doris.s.xin@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>
Closes #916 from dorx/takeSample and squashes the following commits:
5b061ae [Doris Xin] merge master
444e750 [Doris Xin] edge cases
3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939
82dde31 [Xiangrui Meng] update pyspark's takeSample
48d954d [Doris Xin] remove unused imports from RDDSuite
fb1452f [Doris Xin] allowing num to be greater than count in all cases
1481b01 [Doris Xin] washing test tubes and making coffee
dc699f3 [Doris Xin] give back imports removed by accident in rdd.py
64e445b [Doris Xin] logwarnning as soon as it enters the while loop
55518ed [Doris Xin] added TODO for logging in rdd.py
eff89e2 [Doris Xin] addressed reviewer comments.
ecab508 [Doris Xin] "fixed checkstyle violation
0a9b3e3 [Doris Xin] "reviewer comment addressed"
f80f270 [Doris Xin] Merge branch 'master' into takeSample
ae3ad04 [Doris Xin] fixed edge cases to prevent overflow
065ebcd [Doris Xin] Merge branch 'master' into takeSample
9bdd36e [Doris Xin] Check sample size and move computeFraction
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Diffstat (limited to 'core')
6 files changed, 156 insertions, 39 deletions
diff --git a/core/pom.xml b/core/pom.xml index c3d6b00a44..be56911b9e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -68,6 +68,11 @@ <artifactId>commons-lang3</artifactId> </dependency> <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-math3</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>com.google.code.findbugs</groupId> <artifactId>jsr305</artifactId> </dependency> diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b6fc4b13ad..446f369c9e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -42,7 +42,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler} +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -378,46 +378,56 @@ abstract class RDD[T: ClassTag]( }.toArray } - def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = - { - var fraction = 0.0 - var total = 0 - val multiplier = 3.0 - val initialCount = this.count() - var maxSelected = 0 + /** + * Return a fixed-size sampled subset of this RDD in an array + * + * @param withReplacement whether sampling is done with replacement + * @param num size of the returned sample + * @param seed seed for the random number generator + * @return sample of specified size in an array + */ + def takeSample(withReplacement: Boolean, + num: Int, + seed: Long = Utils.random.nextLong): Array[T] = { + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") + } else if (num == 0) { + return new Array[T](0) } + val initialCount = this.count() if (initialCount == 0) { return new Array[T](0) } - if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - 1 - } else { - maxSelected = initialCount.toInt + val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt + if (num > maxSampleSize) { + throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") } - if (num > initialCount && !withReplacement) { - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - } else { - fraction = multiplier * (num + 1) / initialCount - total = num + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + return Utils.randomizeInPlace(this.collect(), rand) } - val rand = new Random(seed) + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; // this shouldn't happen often because we use a big multiplier for the initial size - while (samples.length < total) { + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 } - Utils.randomizeInPlace(samples, rand).take(total) + Utils.randomizeInPlace(samples, rand).take(num) } /** diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 4dc8ada00a..247f10173f 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } /** - * Return a sampler with is the complement of the range specified of the current sampler. + * Return a sampler that is the complement of the range specified of the current sampler. */ def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala new file mode 100644 index 0000000000..a79e3ee756 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +private[spark] object SamplingUtils { + + /** + * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of + * the time. + * + * How the sampling rate is determined: + * Let p = num / total, where num is the sample size and total is the total number of + * datapoints in the RDD. We're trying to compute q > p such that + * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), + * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), + * i.e. the failure rate of not having a sufficiently large sample < 0.0001. + * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for + * num > 12, but we need a slightly larger q (9 empirically determined). + * - when sampling without replacement, we're drawing each datapoint with prob_i + * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success + * rate, where success rate is defined the same as in sampling with replacement. + * + * @param sampleSizeLowerBound sample size + * @param total size of RDD + * @param withReplacement whether sampling with replacement + * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate + */ + def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, + withReplacement: Boolean): Double = { + val fraction = sampleSizeLowerBound.toDouble / total + if (withReplacement) { + val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 + fraction + numStDev * math.sqrt(fraction / total) + } else { + val delta = 1e-4 + val gamma = - math.log(delta) / total + math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 2e2ccc5a18..e94a1e76d4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("takeSample") { - val data = sc.parallelize(1 to 100, 2) + val n = 1000000 + val data = sc.parallelize(1 to n, 2) for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) + val sample = data.takeSample(withReplacement=false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement=true, num=20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, num=100) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, num=n) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, n, seed) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements + val sample = data.takeSample(withReplacement=true, 2 * n, seed) + assert(sample.size === 2 * n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } } diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala new file mode 100644 index 0000000000..accfe2e9b7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} +import org.scalatest.FunSuite + +class SamplingUtilsSuite extends FunSuite { + + test("computeFraction") { + // test that the computed fraction guarantees enough data points + // in the sample with a failure rate <= 0.0001 + val n = 100000 + + for (s <- 1 to 15) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(20, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(1, 10, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, false) + val binomial = new BinomialDistribution(n, frac) + assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") + } + } +} |