aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorDoris Xin <doris.s.xin@gmail.com>2014-08-10 16:31:07 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-10 16:31:07 -0700
commitb715aa0c8090cd57158ead2a1b35632cb98a6277 (patch)
treedef1d42226f24cc572652d4dc4ae916c92ab5e69 /core
parent28dcbb531ae57dc50f15ad9df6c31022731669c9 (diff)
downloadspark-b715aa0c8090cd57158ead2a1b35632cb98a6277.tar.gz
spark-b715aa0c8090cd57158ead2a1b35632cb98a6277.tar.bz2
spark-b715aa0c8090cd57158ead2a1b35632cb98a6277.zip
[SPARK-2937] Separate out samplyByKeyExact as its own API in PairRDDFunction
To enable Python consistency and `Experimental` label of the `sampleByKeyExact` API. Author: Doris Xin <doris.s.xin@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #1866 from dorx/stratified and squashes the following commits: 0ad97b2 [Doris Xin] reviewer comments. 2948aae [Doris Xin] remove unrelated changes e990325 [Doris Xin] Merge branch 'master' into stratified 555a3f9 [Doris Xin] separate out sampleByKeyExact as its own API 616e55c [Doris Xin] merge master 245439e [Doris Xin] moved minSamplingRate to getUpperBound eaf5771 [Doris Xin] bug fixes. 17a381b [Doris Xin] fixed a merge issue and a failed unit ea7d27f [Doris Xin] merge master b223529 [Xiangrui Meng] use approx bounds for poisson fix poisson mean for waitlisting add unit tests for Java b3013a4 [Xiangrui Meng] move math3 back to test scope eecee5f [Doris Xin] Merge branch 'master' into stratified f4c21f3 [Doris Xin] Reviewer comments a10e68d [Doris Xin] style fix a2bf756 [Doris Xin] Merge branch 'master' into stratified 680b677 [Doris Xin] use mapPartitionWithIndex instead 9884a9f [Doris Xin] style fix bbfb8c9 [Doris Xin] Merge branch 'master' into stratified ee9d260 [Doris Xin] addressed reviewer comments 6b5b10b [Doris Xin] Merge branch 'master' into stratified 254e03c [Doris Xin] minor fixes and Java API. 4ad516b [Doris Xin] remove unused imports from PairRDDFunctions bd9dc6e [Doris Xin] unit bug and style violation fixed 1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check 944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate 0214a76 [Doris Xin] cleanUp 90d94c0 [Doris Xin] merge master 9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey 7327611 [Doris Xin] merge master 50581fc [Doris Xin] added a TODO for logging in python 46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being passed into the aggregate function 7e1a481 [Doris Xin] changed the permission on SamplingUtil 1d413ce [Doris Xin] fixed checkstyle issues 9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala68
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala51
-rw-r--r--core/src/test/java/org/apache/spark/JavaAPISuite.java20
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala205
4 files changed, 216 insertions, 128 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 76d4193e96..feeb6c02ca 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return a subset of this RDD sampled by key (via stratified sampling).
*
* Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * If `exact` is set to false, create the sample via simple random sampling, with one pass
- * over the RDD, to produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
- * the RDD to create a sample size that's exactly equal to the sum of
+ * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
+ * RDD, to produce a sample of size that's approximately equal to the sum of
* math.ceil(numItems * samplingRate) over all key values.
*/
def sampleByKey(withReplacement: Boolean,
fractions: JMap[K, Double],
- exact: Boolean,
seed: Long): JavaPairRDD[K, V] =
- new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
+ new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed))
/**
* Return a subset of this RDD sampled by key (via stratified sampling).
*
* Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * If `exact` is set to false, create the sample via simple random sampling, with one pass
- * over the RDD, to produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
- * the RDD to create a sample size that's exactly equal to the sum of
+ * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
+ * RDD, to produce a sample of size that's approximately equal to the sum of
* math.ceil(numItems * samplingRate) over all key values.
*
- * Use Utils.random.nextLong as the default seed for the random number generator
+ * Use Utils.random.nextLong as the default seed for the random number generator.
*/
def sampleByKey(withReplacement: Boolean,
- fractions: JMap[K, Double],
- exact: Boolean): JavaPairRDD[K, V] =
- sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
+ fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+ sampleByKey(withReplacement, fractions, Utils.random.nextLong)
/**
- * Return a subset of this RDD sampled by key (via stratified sampling).
- *
- * Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
+ * ::Experimental::
+ * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
+ * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
*
- * Produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
- * simple random sampling.
+ * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
+ * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
+ * over all key values with a 99.99% confidence. When sampling without replacement, we need one
+ * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
+ * two additional passes.
*/
- def sampleByKey(withReplacement: Boolean,
+ @Experimental
+ def sampleByKeyExact(withReplacement: Boolean,
fractions: JMap[K, Double],
seed: Long): JavaPairRDD[K, V] =
- sampleByKey(withReplacement, fractions, false, seed)
+ new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed))
/**
- * Return a subset of this RDD sampled by key (via stratified sampling).
+ * ::Experimental::
+ * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
+ * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
*
- * Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * Produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
- * simple random sampling.
+ * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
+ * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
+ * over all key values with a 99.99% confidence. When sampling without replacement, we need one
+ * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
+ * two additional passes.
*
- * Use Utils.random.nextLong as the default seed for the random number generator
+ * Use Utils.random.nextLong as the default seed for the random number generator.
*/
- def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
- sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
+ @Experimental
+ def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+ sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong)
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 5dd6472b07..f6d9d12fe9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Return a subset of this RDD sampled by key (via stratified sampling).
*
* Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * If `exact` is set to false, create the sample via simple random sampling, with one pass
- * over the RDD, to produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values; otherwise, use
- * additional passes over the RDD to create a sample size that's exactly equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling
- * without replacement, we need one additional pass over the RDD to guarantee sample size;
- * when sampling with replacement, we need two additional passes.
+ * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
+ * RDD, to produce a sample of size that's approximately equal to the sum of
+ * math.ceil(numItems * samplingRate) over all key values.
*
* @param withReplacement whether to sample with or without replacement
* @param fractions map of specific keys to sampling rates
* @param seed seed for the random number generator
- * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
* @return RDD containing the sampled subset
*/
def sampleByKey(withReplacement: Boolean,
fractions: Map[K, Double],
- exact: Boolean = false,
- seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
+ seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+ require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
+
+ val samplingFunc = if (withReplacement) {
+ StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed)
+ } else {
+ StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)
+ }
+ self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+ }
+
+ /**
+ * ::Experimental::
+ * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
+ * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
+ *
+ * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
+ * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
+ * over all key values with a 99.99% confidence. When sampling without replacement, we need one
+ * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
+ * two additional passes.
+ *
+ * @param withReplacement whether to sample with or without replacement
+ * @param fractions map of specific keys to sampling rates
+ * @param seed seed for the random number generator
+ * @return RDD containing the sampled subset
+ */
+ @Experimental
+ def sampleByKeyExact(withReplacement: Boolean,
+ fractions: Map[K, Double],
+ seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
val samplingFunc = if (withReplacement) {
- StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed)
+ StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed)
} else {
- StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed)
+ StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)
}
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 56150caa5d..e1c13de04a 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1239,12 +1239,28 @@ public class JavaAPISuite implements Serializable {
Assert.assertTrue(worCounts.size() == 2);
Assert.assertTrue(worCounts.get(0) > 0);
Assert.assertTrue(worCounts.get(1) > 0);
- JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKey(true, fractions, true, 1L);
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void sampleByKeyExact() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3);
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
+ new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) {
+ return new Tuple2<Integer, Integer>(i % 2, 1);
+ }
+ });
+ Map<Integer, Object> fractions = Maps.newHashMap();
+ fractions.put(0, 0.5);
+ fractions.put(1, 1.0);
+ JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKeyExact(true, fractions, 1L);
Map<Integer, Long> wrExactCounts = (Map<Integer, Long>) (Object) wrExact.countByKey();
Assert.assertTrue(wrExactCounts.size() == 2);
Assert.assertTrue(wrExactCounts.get(0) == 2);
Assert.assertTrue(wrExactCounts.get(1) == 4);
- JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKey(false, fractions, true, 1L);
+ JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKeyExact(false, fractions, 1L);
Map<Integer, Long> worExactCounts = (Map<Integer, Long>) (Object) worExact.countByKey();
Assert.assertTrue(worExactCounts.size() == 2);
Assert.assertTrue(worExactCounts.get(0) == 2);
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 4f49d4a1d4..63d3ddb4af 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -84,118 +84,81 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
}
test("sampleByKey") {
- def stratifier (fractionPositive: Double) = {
- (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
- }
- def checkSize(exact: Boolean,
- withReplacement: Boolean,
- expected: Long,
- actual: Long,
- p: Double): Boolean = {
- if (exact) {
- return expected == actual
- }
- val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p))
- // Very forgiving margin since we're dealing with very small sample sizes most of the time
- math.abs(actual - expected) <= 6 * stdev
+ val defaultSeed = 1L
+
+ // vary RDD size
+ for (n <- List(100, 1000, 1000000)) {
+ val data = sc.parallelize(1 to n, 2)
+ val fractionPositive = 0.3
+ val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
+ val samplingRate = 0.1
+ StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n)
}
- // Without replacement validation
- def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)],
- exact: Boolean,
- samplingRate: Double,
- seed: Long,
- n: Long) = {
- val expectedSampleSize = stratifiedData.countByKey()
- .mapValues(count => math.ceil(count * samplingRate).toInt)
- val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
- val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
- val sampleCounts = sample.countByKey()
- val takeSample = sample.collect()
- sampleCounts.foreach { case(k, v) =>
- assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) }
- assert(takeSample.size === takeSample.toSet.size)
- takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
+ // vary fractionPositive
+ for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
+ val n = 100
+ val data = sc.parallelize(1 to n, 2)
+ val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
+ val samplingRate = 0.1
+ StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n)
}
- // With replacement validation
- def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)],
- exact: Boolean,
- samplingRate: Double,
- seed: Long,
- n: Long) = {
- val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
- math.ceil(count * samplingRate).toInt)
- val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
- val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
- val sampleCounts = sample.countByKey()
- val takeSample = sample.collect()
- sampleCounts.foreach { case(k, v) =>
- assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) }
- val groupedByKey = takeSample.groupBy(_._1)
- for ((key, v) <- groupedByKey) {
- if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
- // sample large enough for there to be repeats with high likelihood
- assert(v.toSet.size < expectedSampleSize(key))
- } else {
- if (exact) {
- assert(v.toSet.size <= expectedSampleSize(key))
- } else {
- assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate))
- }
- }
- }
- takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
+ // Use the same data for the rest of the tests
+ val fractionPositive = 0.3
+ val n = 100
+ val data = sc.parallelize(1 to n, 2)
+ val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
+
+ // vary seed
+ for (seed <- defaultSeed to defaultSeed + 5L) {
+ val samplingRate = 0.1
+ StratifiedAuxiliary.testSample(stratifiedData, samplingRate, seed, n)
}
- def checkAllCombos(stratifiedData: RDD[(String, Int)],
- samplingRate: Double,
- seed: Long,
- n: Long) = {
- takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n)
- takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n)
- takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n)
- takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n)
+ // vary sampling rate
+ for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
+ StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n)
}
+ }
+ test("sampleByKeyExact") {
val defaultSeed = 1L
// vary RDD size
for (n <- List(100, 1000, 1000000)) {
val data = sc.parallelize(1 to n, 2)
val fractionPositive = 0.3
- val stratifiedData = data.keyBy(stratifier(fractionPositive))
-
+ val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val samplingRate = 0.1
- checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+ StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n)
}
// vary fractionPositive
for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
val n = 100
val data = sc.parallelize(1 to n, 2)
- val stratifiedData = data.keyBy(stratifier(fractionPositive))
-
+ val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val samplingRate = 0.1
- checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+ StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n)
}
// Use the same data for the rest of the tests
val fractionPositive = 0.3
val n = 100
val data = sc.parallelize(1 to n, 2)
- val stratifiedData = data.keyBy(stratifier(fractionPositive))
+ val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
// vary seed
for (seed <- defaultSeed to defaultSeed + 5L) {
val samplingRate = 0.1
- checkAllCombos(stratifiedData, samplingRate, seed, n)
+ StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, seed, n)
}
// vary sampling rate
for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
- checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+ StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n)
}
}
@@ -556,6 +519,98 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
intercept[IllegalArgumentException] {shuffled.lookup(-1)}
}
+ private object StratifiedAuxiliary {
+ def stratifier (fractionPositive: Double) = {
+ (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
+ }
+
+ def checkSize(exact: Boolean,
+ withReplacement: Boolean,
+ expected: Long,
+ actual: Long,
+ p: Double): Boolean = {
+ if (exact) {
+ return expected == actual
+ }
+ val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p))
+ // Very forgiving margin since we're dealing with very small sample sizes most of the time
+ math.abs(actual - expected) <= 6 * stdev
+ }
+
+ def testSampleExact(stratifiedData: RDD[(String, Int)],
+ samplingRate: Double,
+ seed: Long,
+ n: Long) = {
+ testBernoulli(stratifiedData, true, samplingRate, seed, n)
+ testPoisson(stratifiedData, true, samplingRate, seed, n)
+ }
+
+ def testSample(stratifiedData: RDD[(String, Int)],
+ samplingRate: Double,
+ seed: Long,
+ n: Long) = {
+ testBernoulli(stratifiedData, false, samplingRate, seed, n)
+ testPoisson(stratifiedData, false, samplingRate, seed, n)
+ }
+
+ // Without replacement validation
+ def testBernoulli(stratifiedData: RDD[(String, Int)],
+ exact: Boolean,
+ samplingRate: Double,
+ seed: Long,
+ n: Long) = {
+ val expectedSampleSize = stratifiedData.countByKey()
+ .mapValues(count => math.ceil(count * samplingRate).toInt)
+ val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+ val sample = if (exact) {
+ stratifiedData.sampleByKeyExact(false, fractions, seed)
+ } else {
+ stratifiedData.sampleByKey(false, fractions, seed)
+ }
+ val sampleCounts = sample.countByKey()
+ val takeSample = sample.collect()
+ sampleCounts.foreach { case(k, v) =>
+ assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) }
+ assert(takeSample.size === takeSample.toSet.size)
+ takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
+ }
+
+ // With replacement validation
+ def testPoisson(stratifiedData: RDD[(String, Int)],
+ exact: Boolean,
+ samplingRate: Double,
+ seed: Long,
+ n: Long) = {
+ val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
+ math.ceil(count * samplingRate).toInt)
+ val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+ val sample = if (exact) {
+ stratifiedData.sampleByKeyExact(true, fractions, seed)
+ } else {
+ stratifiedData.sampleByKey(true, fractions, seed)
+ }
+ val sampleCounts = sample.countByKey()
+ val takeSample = sample.collect()
+ sampleCounts.foreach { case (k, v) =>
+ assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate))
+ }
+ val groupedByKey = takeSample.groupBy(_._1)
+ for ((key, v) <- groupedByKey) {
+ if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
+ // sample large enough for there to be repeats with high likelihood
+ assert(v.toSet.size < expectedSampleSize(key))
+ } else {
+ if (exact) {
+ assert(v.toSet.size <= expectedSampleSize(key))
+ } else {
+ assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate))
+ }
+ }
+ }
+ takeSample.foreach(x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]"))
+ }
+ }
+
}
/*