author     Doris Xin <doris.s.xin@gmail.com>  2014-06-12 19:44:27 -0700
committer  Xiangrui Meng <meng@databricks.com>  2014-06-12 19:44:27 -0700
commit     1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47 (patch)
tree       f99459c7412db3dd9479037c41e5a4055853ae09 /core
parent     0154587ab71d1b864f97497dbb38bc52b87675be (diff)
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes a sampling rate > sample_size/total, ensuring a sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate the choice of sampling rate.

Author: Doris Xin <doris.s.xin@gmail.com>
Author: dorx <doris.s.xin@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #916 from dorx/takeSample and squashes the following commits:

5b061ae [Doris Xin] merge master
444e750 [Doris Xin] edge cases
3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939
82dde31 [Xiangrui Meng] update pyspark's takeSample
48d954d [Doris Xin] remove unused imports from RDDSuite
fb1452f [Doris Xin] allowing num to be greater than count in all cases
1481b01 [Doris Xin] washing test tubes and making coffee
dc699f3 [Doris Xin] give back imports removed by accident in rdd.py
64e445b [Doris Xin] logwarnning as soon as it enters the while loop
55518ed [Doris Xin] added TODO for logging in rdd.py
eff89e2 [Doris Xin] addressed reviewer comments.
ecab508 [Doris Xin] "fixed checkstyle violation
0a9b3e3 [Doris Xin] "reviewer comment addressed"
f80f270 [Doris Xin] Merge branch 'master' into takeSample
ae3ad04 [Doris Xin] fixed edge cases to prevent overflow
065ebcd [Doris Xin] Merge branch 'master' into takeSample
9bdd36e [Doris Xin] Check sample size and move computeFraction
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
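For orientation, a minimal sketch of what the new helper computes (not part of this commit; the numbers are illustrative, and SamplingUtils is private[spark], so this only runs inside Spark's own packages):

    import org.apache.spark.util.random.SamplingUtils

    // Illustrative target: at least 100 elements out of 1,000,000.
    val total = 1000000L
    val num = 100
    // ScaSRS picks a rate slightly above num/total so that a single pass over
    // the data returns >= num elements with probability >= 0.9999.
    val fraction =
      SamplingUtils.computeFractionForSampleSize(num, total, withReplacement = false)
    // fraction ~= 1.5e-4, versus the naive num/total = 1e-4.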
Diffstat (limited to 'core')
-rw-r--r--  core/pom.xml                                                                 5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                          52
-rw-r--r--  core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala         2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala        55
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala                     35
-rw-r--r--  core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala   46
6 files changed, 156 insertions(+), 39 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index c3d6b00a44..be56911b9e 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -68,6 +68,11 @@
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math3</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b6fc4b13ad..446f369c9e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -42,7 +42,7 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -378,46 +378,56 @@ abstract class RDD[T: ClassTag](
}.toArray
}
- def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
- {
- var fraction = 0.0
- var total = 0
- val multiplier = 3.0
- val initialCount = this.count()
- var maxSelected = 0
+ /**
+ * Return a fixed-size sampled subset of this RDD in an array
+ *
+ * @param withReplacement whether sampling is done with replacement
+ * @param num size of the returned sample
+ * @param seed seed for the random number generator
+ * @return sample of specified size in an array
+ */
+ def takeSample(withReplacement: Boolean,
+ num: Int,
+ seed: Long = Utils.random.nextLong): Array[T] = {
+ val numStDev = 10.0
if (num < 0) {
throw new IllegalArgumentException("Negative number of elements requested")
+ } else if (num == 0) {
+ return new Array[T](0)
}
+ val initialCount = this.count()
if (initialCount == 0) {
return new Array[T](0)
}
- if (initialCount > Integer.MAX_VALUE - 1) {
- maxSelected = Integer.MAX_VALUE - 1
- } else {
- maxSelected = initialCount.toInt
+ val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
+ if (num > maxSampleSize) {
+ throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
+ s"$numStDev * math.sqrt(Int.MaxValue)")
}
- if (num > initialCount && !withReplacement) {
- total = maxSelected
- fraction = multiplier * (maxSelected + 1) / initialCount
- } else {
- fraction = multiplier * (num + 1) / initialCount
- total = num
+ val rand = new Random(seed)
+ if (!withReplacement && num >= initialCount) {
+ return Utils.randomizeInPlace(this.collect(), rand)
}
- val rand = new Random(seed)
+ val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
+ withReplacement)
+
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
// If the first sample didn't turn out large enough, keep trying to take samples;
// this shouldn't happen often because we use a big multiplier for the initial size
- while (samples.length < total) {
+ var numIters = 0
+ while (samples.length < num) {
+ logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+ numIters += 1
}
- Utils.randomizeInPlace(samples, rand).take(total)
+ Utils.randomizeInPlace(samples, rand).take(num)
}
/**
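Net effect of the hunk above: takeSample now validates num, short-circuits the empty and whole-RDD cases, computes one sampling fraction via SamplingUtils, and loops (with a warning) only in the statistically rare case that the first pass comes up short. A hedged usage sketch, assuming an active SparkContext sc and hypothetical data:

    val rdd = sc.parallelize(1 to 1000000, 4)              // hypothetical RDD
    val s1 = rdd.takeSample(withReplacement = false, num = 100, seed = 42L)
    assert(s1.length == 100 && s1.distinct.length == 100)  // distinct elements
    val s2 = rdd.takeSample(withReplacement = true, num = 100)
    assert(s2.length == 100)                               // duplicates possible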
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 4dc8ada00a..247f10173f 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
}
/**
- * Return a sampler with is the complement of the range specified of the current sampler.
+ * Return a sampler that is the complement of the range specified by the current sampler.
*/
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
new file mode 100644
index 0000000000..a79e3ee756
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+private[spark] object SamplingUtils {
+
+ /**
+ * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
+ * the time.
+ *
+ * How the sampling rate is determined:
+ * Let p = num / total, where num is the sample size and total is the total number of
+ * datapoints in the RDD. We're trying to compute q > p such that
+ * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
+ * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 1 to total),
+ * i.e. the failure rate of not having a sufficiently large sample < 0.0001.
+ * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
+ * num > 12, but we need a slightly larger q (9 empirically determined).
+ * - when sampling without replacement, we're drawing each datapoint with prob_i
+ * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
+ * rate, where success rate is defined the same as in sampling with replacement.
+ *
+ * @param sampleSizeLowerBound sample size
+ * @param total size of RDD
+ * @param withReplacement whether sampling with replacement
+ * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
+ */
+ def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
+ withReplacement: Boolean): Double = {
+ val fraction = sampleSizeLowerBound.toDouble / total
+ if (withReplacement) {
+ val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
+ fraction + numStDev * math.sqrt(fraction / total)
+ } else {
+ val delta = 1e-4
+ val gamma = - math.log(delta) / total
+ math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+ }
+ }
+}
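To make the without-replacement bound concrete, a worked instance of the formula above (illustrative numbers, not from the commit):

    // Aiming for >= 100 of 1,000,000 elements with failure rate delta = 1e-4.
    val p = 100.0 / 1000000                       // naive rate: 1e-4
    val gamma = -math.log(1e-4) / 1000000         // ~9.21e-6
    val q = p + gamma + math.sqrt(gamma * gamma + 2 * gamma * p)
    // q ~= 1.53e-4, about 1.5x the naive rate; the extra ~50% oversampling is
    // the price of getting >= 100 elements in one pass 99.99% of the time.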
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2e2ccc5a18..e94a1e76d4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
test("takeSample") {
- val data = sc.parallelize(1 to 100, 2)
+ val n = 1000000
+ val data = sc.parallelize(1 to n, 2)
for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size === 20) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=false, 200, seed)
+ val sample = data.takeSample(withReplacement=false, 100, seed)
assert(sample.size === 100) // Got exactly 100 elements
assert(sample.toSet.size === 100) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
- val sample = data.takeSample(withReplacement=true, num=100)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, num=n)
+ assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < n
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 100, seed)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, n, seed)
+ assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < n
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 200, seed)
- assert(sample.size === 200) // Got exactly 200 elements
+ val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+ assert(sample.size === 2 * n) // Got exactly 2 * n elements
// Chance of getting all distinct elements is still quite low, so test we got < n
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
new file mode 100644
index 0000000000..accfe2e9b7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
+import org.scalatest.FunSuite
+
+class SamplingUtilsSuite extends FunSuite {
+
+ test("computeFraction") {
+ // test that the computed fraction guarantees enough data points
+ // in the sample with a failure rate <= 0.0001
+ val n = 100000
+
+ for (s <- 1 to 15) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+ val poisson = new PoissonDistribution(frac * n)
+ assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(20, 100, 1000)) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+ val poisson = new PoissonDistribution(frac * n)
+ assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(1, 10, 100, 1000)) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
+ val binomial = new BinomialDistribution(n, frac)
+ assert(binomial.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ }
+}
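The with-replacement branch of this test can be sanity-checked in isolation. A sketch using the same commons-math3 API (illustrative numbers: for a target of 100 out of 1e6, the commit's formula gives q = 1e-4 + 5 * sqrt(1e-10) = 1.5e-4):

    import org.apache.commons.math3.distribution.PoissonDistribution

    // The number of sampled elements is ~ Pois(q * n), i.e. mean 150; even its
    // 0.01% lower quantile should clear the target of 100.
    val poisson = new PoissonDistribution(1.5e-4 * 1000000)
    assert(poisson.inverseCumulativeProbability(0.0001) >= 100)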