diff options
author | Doris Xin <doris.s.xin@gmail.com> | 2014-07-24 23:42:08 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-07-24 23:42:08 -0700 |
commit | 2f75a4a30e1a3fdf384475b9660c6c43f093f68c (patch) | |
tree | eb50e720cae6842bcb242d030adb27fba92c3f62 /python/pyspark/rddsampler.py | |
parent | 14174abd421318e71c16edd24224fd5094bdfed4 (diff) | |
download | spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.gz spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.bz2 spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.zip |
[SPARK-2656] Python version of stratified sampling
exact sample size not supported for now.
Author: Doris Xin <doris.s.xin@gmail.com>
Closes #1554 from dorx/pystratified and squashes the following commits:
4ba927a [Doris Xin] use rel diff (+- 50%) instead of abs diff (+- 50)
bdc3f8b [Doris Xin] updated unit to check sample holistically
7713c7b [Doris Xin] Python version of stratified sampling
Diffstat (limited to 'python/pyspark/rddsampler.py')
-rw-r--r-- | python/pyspark/rddsampler.py | 30 |
1 files changed, 27 insertions, 3 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 7ff1c316c7..2df000fdb0 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -19,8 +19,8 @@ import sys import random -class RDDSampler(object): - def __init__(self, withReplacement, fraction, seed=None): +class RDDSamplerBase(object): + def __init__(self, withReplacement, seed=None): try: import numpy self._use_numpy = True @@ -32,7 +32,6 @@ class RDDSampler(object): self._seed = seed if seed is not None else random.randint(0, sys.maxint) self._withReplacement = withReplacement - self._fraction = fraction self._random = None self._split = None self._rand_initialized = False @@ -94,6 +93,12 @@ class RDDSampler(object): else: self._random.shuffle(vals, self._random.random) + +class RDDSampler(RDDSamplerBase): + def __init__(self, withReplacement, fraction, seed=None): + RDDSamplerBase.__init__(self, withReplacement, seed) + self._fraction = fraction + def func(self, split, iterator): if self._withReplacement: for obj in iterator: @@ -107,3 +112,22 @@ class RDDSampler(object): for obj in iterator: if self.getUniformSample(split) <= self._fraction: yield obj + +class RDDStratifiedSampler(RDDSamplerBase): + def __init__(self, withReplacement, fractions, seed=None): + RDDSamplerBase.__init__(self, withReplacement, seed) + self._fractions = fractions + + def func(self, split, iterator): + if self._withReplacement: + for key, val in iterator: + # For large datasets, the expected number of occurrences of each element in + # a sample with replacement is Poisson(frac). We use that to get a count for + # each element. + count = self.getPoissonSample(split, mean=self._fractions[key]) + for _ in range(0, count): + yield key, val + else: + for key, val in iterator: + if self.getUniformSample(split) <= self._fractions[key]: + yield key, val |