aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rddsampler.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/rddsampler.py')
-rw-r--r--python/pyspark/rddsampler.py31
1 files changed, 14 insertions, 17 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index aca2ef3b51..845a267e31 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,7 +19,7 @@ import sys
import random
class RDDSampler(object):
- def __init__(self, withReplacement, fraction, seed):
+ def __init__(self, withReplacement, fraction, seed=None):
try:
import numpy
self._use_numpy = True
@@ -27,7 +27,7 @@ class RDDSampler(object):
print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
self._use_numpy = False
- self._seed = seed
+ self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
self._fraction = fraction
self._random = None
@@ -38,17 +38,14 @@ class RDDSampler(object):
if self._use_numpy:
import numpy
self._random = numpy.random.RandomState(self._seed)
- for _ in range(0, split):
- # discard the next few values in the sequence to have a
- # different seed for the different splits
- self._random.randint(sys.maxint)
else:
- import random
- random.seed(self._seed)
- for _ in range(0, split):
- # discard the next few values in the sequence to have a
- # different seed for the different splits
- random.randint(0, sys.maxint)
+ self._random = random.Random(self._seed)
+
+ for _ in range(0, split):
+ # discard the next few values in the sequence to have a
+ # different seed for the different splits
+ self._random.randint(0, sys.maxint)
+
self._split = split
self._rand_initialized = True
@@ -59,7 +56,7 @@ class RDDSampler(object):
if self._use_numpy:
return self._random.random_sample()
else:
- return random.uniform(0.0, 1.0)
+ return self._random.uniform(0.0, 1.0)
def getPoissonSample(self, split, mean):
if not self._rand_initialized or split != self._split:
@@ -73,26 +70,26 @@ class RDDSampler(object):
num_arrivals = 1
cur_time = 0.0
- cur_time += random.expovariate(mean)
+ cur_time += self._random.expovariate(mean)
if cur_time > 1.0:
return 0
while(cur_time <= 1.0):
- cur_time += random.expovariate(mean)
+ cur_time += self._random.expovariate(mean)
num_arrivals += 1
return (num_arrivals - 1)
def shuffle(self, vals):
- if self._random == None or split != self._split:
+ if self._random == None:
self.initRandomGenerator(0) # this should only ever called on the master so
# the split does not matter
if self._use_numpy:
self._random.shuffle(vals)
else:
- random.shuffle(vals, self._random)
+ self._random.shuffle(vals, self._random.random)
def func(self, split, iterator):
if self._withReplacement: