diff options
Diffstat (limited to 'core/src/test')
-rw-r--r-- | core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 35 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala | 46 |
2 files changed, 64 insertions, 17 deletions
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 2e2ccc5a18..e94a1e76d4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("takeSample") { - val data = sc.parallelize(1 to 100, 2) + val n = 1000000 + val data = sc.parallelize(1 to n, 2) for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) + val sample = data.takeSample(withReplacement=false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement=true, num=20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, num=100) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, num=n) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, n, seed) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements + val sample = data.takeSample(withReplacement=true, 2 * n, seed) + assert(sample.size === 2 * n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } } diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala new file mode 100644 index 0000000000..accfe2e9b7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} +import org.scalatest.FunSuite + +class SamplingUtilsSuite extends FunSuite { + + test("computeFraction") { + // test that the computed fraction guarantees enough data points + // in the sample with a failure rate <= 0.0001 + val n = 100000 + + for (s <- 1 to 15) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(20, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(1, 10, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, false) + val binomial = new BinomialDistribution(n, frac) + assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") + } + } +} |