diff options
Diffstat (limited to 'python/pyspark/rddsampler.py')
-rw-r--r-- | python/pyspark/rddsampler.py | 31 |
1 files changed, 14 insertions, 17 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index aca2ef3b51..845a267e31 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -19,7 +19,7 @@ import sys import random class RDDSampler(object): - def __init__(self, withReplacement, fraction, seed): + def __init__(self, withReplacement, fraction, seed=None): try: import numpy self._use_numpy = True @@ -27,7 +27,7 @@ class RDDSampler(object): print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling." self._use_numpy = False - self._seed = seed + self._seed = seed if seed is not None else random.randint(0, sys.maxint) self._withReplacement = withReplacement self._fraction = fraction self._random = None @@ -38,17 +38,14 @@ class RDDSampler(object): if self._use_numpy: import numpy self._random = numpy.random.RandomState(self._seed) - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - self._random.randint(sys.maxint) else: - import random - random.seed(self._seed) - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - random.randint(0, sys.maxint) + self._random = random.Random(self._seed) + + for _ in range(0, split): + # discard the next few values in the sequence to have a + # different seed for the different splits + self._random.randint(0, sys.maxint) + self._split = split self._rand_initialized = True @@ -59,7 +56,7 @@ class RDDSampler(object): if self._use_numpy: return self._random.random_sample() else: - return random.uniform(0.0, 1.0) + return self._random.uniform(0.0, 1.0) def getPoissonSample(self, split, mean): if not self._rand_initialized or split != self._split: @@ -73,26 +70,26 @@ class RDDSampler(object): num_arrivals = 1 cur_time = 0.0 - cur_time += random.expovariate(mean) + cur_time += self._random.expovariate(mean) if cur_time > 1.0: return 0 while(cur_time <= 1.0): - cur_time += random.expovariate(mean) + cur_time += self._random.expovariate(mean) num_arrivals += 1 return (num_arrivals - 1) def shuffle(self, vals): - if self._random == None or split != self._split: + if self._random == None: self.initRandomGenerator(0) # this should only ever called on the master so # the split does not matter if self._use_numpy: self._random.shuffle(vals) else: - random.shuffle(vals, self._random) + self._random.shuffle(vals, self._random.random) def func(self, split, iterator): if self._withReplacement: |