aboutsummaryrefslogtreecommitdiff
path: root/core/src/test
diff options
context:
space:
mode:
authorDoris Xin <doris.s.xin@gmail.com>2014-06-12 19:44:27 -0700
committerXiangrui Meng <meng@databricks.com>2014-06-12 19:44:27 -0700
commit1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47 (patch)
treef99459c7412db3dd9479037c41e5a4055853ae09 /core/src/test
parent0154587ab71d1b864f97497dbb38bc52b87675be (diff)
downloadspark-1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47.tar.gz
spark-1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47.tar.bz2
spark-1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47.zip
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes sampling rate > sample_size/total to ensure sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate choice of sampling rate. Author: Doris Xin <doris.s.xin@gmail.com> Author: dorx <doris.s.xin@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #916 from dorx/takeSample and squashes the following commits: 5b061ae [Doris Xin] merge master 444e750 [Doris Xin] edge cases 3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939 82dde31 [Xiangrui Meng] update pyspark's takeSample 48d954d [Doris Xin] remove unused imports from RDDSuite fb1452f [Doris Xin] allowing num to be greater than count in all cases 1481b01 [Doris Xin] washing test tubes and making coffee dc699f3 [Doris Xin] give back imports removed by accident in rdd.py 64e445b [Doris Xin] logwarnning as soon as it enters the while loop 55518ed [Doris Xin] added TODO for logging in rdd.py eff89e2 [Doris Xin] addressed reviewer comments. ecab508 [Doris Xin] "fixed checkstyle violation 0a9b3e3 [Doris Xin] "reviewer comment addressed" f80f270 [Doris Xin] Merge branch 'master' into takeSample ae3ad04 [Doris Xin] fixed edge cases to prevent overflow 065ebcd [Doris Xin] Merge branch 'master' into takeSample 9bdd36e [Doris Xin] Check sample size and move computeFraction e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Diffstat (limited to 'core/src/test')
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala35
-rw-r--r--core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala46
2 files changed, 64 insertions, 17 deletions
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2e2ccc5a18..e94a1e76d4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
test("takeSample") {
- val data = sc.parallelize(1 to 100, 2)
+ val n = 1000000
+ val data = sc.parallelize(1 to n, 2)
for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size === 20) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=false, 200, seed)
+ val sample = data.takeSample(withReplacement=false, 100, seed)
assert(sample.size === 100) // Got only 100 elements
assert(sample.toSet.size === 100) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 100 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
- val sample = data.takeSample(withReplacement=true, num=100)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, num=n)
+ assert(sample.size === n) // Got exactly 100 elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 100, seed)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, n, seed)
+ assert(sample.size === n) // Got exactly 100 elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 200, seed)
- assert(sample.size === 200) // Got exactly 200 elements
+ val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+ assert(sample.size === 2 * n) // Got exactly 200 elements
// Chance of getting all distinct elements is still quite low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
new file mode 100644
index 0000000000..accfe2e9b7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
+import org.scalatest.FunSuite
+
+class SamplingUtilsSuite extends FunSuite {
+
+ test("computeFraction") {
+ // test that the computed fraction guarantees enough data points
+ // in the sample with a failure rate <= 0.0001
+ val n = 100000
+
+ for (s <- 1 to 15) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+ val poisson = new PoissonDistribution(frac * n)
+ assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(20, 100, 1000)) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+ val poisson = new PoissonDistribution(frac * n)
+ assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(1, 10, 100, 1000)) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
+ val binomial = new BinomialDistribution(n, frac)
+ assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
+ }
+ }
+}