aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rddsampler.py
diff options
context:
space:
mode:
authorDoris Xin <doris.s.xin@gmail.com>2014-07-24 23:42:08 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-24 23:42:08 -0700
commit2f75a4a30e1a3fdf384475b9660c6c43f093f68c (patch)
treeeb50e720cae6842bcb242d030adb27fba92c3f62 /python/pyspark/rddsampler.py
parent14174abd421318e71c16edd24224fd5094bdfed4 (diff)
downloadspark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.gz
spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.bz2
spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.zip
[SPARK-2656] Python version of stratified sampling
exact sample size not supported for now. Author: Doris Xin <doris.s.xin@gmail.com> Closes #1554 from dorx/pystratified and squashes the following commits: 4ba927a [Doris Xin] use rel diff (+- 50%) instead of abs diff (+- 50) bdc3f8b [Doris Xin] updated unit to check sample holistically 7713c7b [Doris Xin] Python version of stratified sampling
Diffstat (limited to 'python/pyspark/rddsampler.py')
-rw-r--r--python/pyspark/rddsampler.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 7ff1c316c7..2df000fdb0 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,8 +19,8 @@ import sys
import random
-class RDDSampler(object):
- def __init__(self, withReplacement, fraction, seed=None):
+class RDDSamplerBase(object):
+ def __init__(self, withReplacement, seed=None):
try:
import numpy
self._use_numpy = True
@@ -32,7 +32,6 @@ class RDDSampler(object):
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
- self._fraction = fraction
self._random = None
self._split = None
self._rand_initialized = False
@@ -94,6 +93,12 @@ class RDDSampler(object):
else:
self._random.shuffle(vals, self._random.random)
+
+class RDDSampler(RDDSamplerBase):
+ def __init__(self, withReplacement, fraction, seed=None):
+ RDDSamplerBase.__init__(self, withReplacement, seed)
+ self._fraction = fraction
+
def func(self, split, iterator):
if self._withReplacement:
for obj in iterator:
@@ -107,3 +112,22 @@ class RDDSampler(object):
for obj in iterator:
if self.getUniformSample(split) <= self._fraction:
yield obj
+
+class RDDStratifiedSampler(RDDSamplerBase):
+ def __init__(self, withReplacement, fractions, seed=None):
+ RDDSamplerBase.__init__(self, withReplacement, seed)
+ self._fractions = fractions
+
+ def func(self, split, iterator):
+ if self._withReplacement:
+ for key, val in iterator:
+ # For large datasets, the expected number of occurrences of each element in
+ # a sample with replacement is Poisson(frac). We use that to get a count for
+ # each element.
+ count = self.getPoissonSample(split, mean=self._fractions[key])
+ for _ in range(0, count):
+ yield key, val
+ else:
+ for key, val in iterator:
+ if self.getUniformSample(split) <= self._fractions[key]:
+ yield key, val