Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py         13
-rw-r--r--  python/pyspark/rddsampler.py  31
2 files changed, 20 insertions, 24 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 91fc7e637e..d73ab7006e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,6 +30,7 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
 import heapq
+from random import Random

 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -332,7 +333,7 @@ class RDD(object):
                    .reduceByKey(lambda x, _: x) \
                    .map(lambda (x, _): x)

-    def sample(self, withReplacement, fraction, seed):
+    def sample(self, withReplacement, fraction, seed=None):
         """
         Return a sampled subset of this RDD (relies on numpy and falls back
         on default random generator if numpy is unavailable).
@@ -344,7 +345,7 @@ class RDD(object):
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

     # this is ported from scala/spark/RDD.scala
-    def takeSample(self, withReplacement, num, seed):
+    def takeSample(self, withReplacement, num, seed=None):
         """
         Return a fixed-size sampled subset of this RDD (currently requires numpy).
@@ -381,13 +382,11 @@ class RDD(object):
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
+        rand = Random(seed)
         while len(samples) < total:
-            if seed > sys.maxint - 2:
-                seed = -1
-            seed += 1
-            samples = self.sample(withReplacement, fraction, seed).collect()
+            samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()

-        sampler = RDDSampler(withReplacement, fraction, seed+1)
+        sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
         sampler.shuffle(samples)
         return samples[0:total]
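Aside on the hunk above: takeSample now seeds one master Random(seed) and draws a fresh child seed for every retry, which removes the old hand-rolled overflow guard around sys.maxint. A minimal standalone sketch of the same pattern (not part of the patch; Python 3 shown, so sys.maxsize stands in for Python 2's sys.maxint, and retry_seeds is a hypothetical name):

import random
import sys

def retry_seeds(master_seed=None, attempts=3):
    # One master generator, seeded once; None falls back to an OS-provided
    # seed, mirroring the new seed=None defaults in the patch.
    rand = random.Random(master_seed)
    # Each retry draws an independent child seed, so there is no need to
    # bump `seed` by hand or guard against integer overflow.
    return [rand.randint(0, sys.maxsize) for _ in range(attempts)]

print(retry_seeds(42))  # reproducible: same master seed, same child seeds
print(retry_seeds())    # different on every run when no seed is given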
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index aca2ef3b51..845a267e31 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,7 +19,7 @@ import sys
 import random

 class RDDSampler(object):
-    def __init__(self, withReplacement, fraction, seed):
+    def __init__(self, withReplacement, fraction, seed=None):
         try:
             import numpy
             self._use_numpy = True
@@ -27,7 +27,7 @@ class RDDSampler(object):
             print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
             self._use_numpy = False

-        self._seed = seed
+        self._seed = seed if seed is not None else random.randint(0, sys.maxint)
         self._withReplacement = withReplacement
         self._fraction = fraction
         self._random = None
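A side note on the self._seed line above: resolving an omitted seed to a concrete integer in the constructor means the sampler can be serialized and shipped to every worker while still behaving identically on each of them. A tiny sketch of the idiom (hypothetical Sampler class, not the patch's code; sys.maxsize used in place of Python 2's sys.maxint):

import random
import sys

class Sampler(object):
    def __init__(self, seed=None):
        # Fix the seed once, up front; letting each worker copy call
        # random.Random(None) would give every copy a different stream.
        self._seed = seed if seed is not None else random.randint(0, sys.maxsize)

print(Sampler(7)._seed)  # 7: explicit seeds pass through unchanged
print(Sampler()._seed)   # a random, but now fixed, fallback seed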
@@ -38,17 +38,14 @@
         if self._use_numpy:
             import numpy
             self._random = numpy.random.RandomState(self._seed)
-            for _ in range(0, split):
-                # discard the next few values in the sequence to have a
-                # different seed for the different splits
-                self._random.randint(sys.maxint)
         else:
-            import random
-            random.seed(self._seed)
-            for _ in range(0, split):
-                # discard the next few values in the sequence to have a
-                # different seed for the different splits
-                random.randint(0, sys.maxint)
+            self._random = random.Random(self._seed)
+
+        for _ in range(0, split):
+            # discard the next few values in the sequence to have a
+            # different seed for the different splits
+            self._random.randint(0, sys.maxint)
+
         self._split = split
         self._rand_initialized = True
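The hunk above folds the numpy and stdlib branches into a single self._random, keeping the existing per-split trick: seed every partition identically, then discard `split` draws so the partitions diverge. A rough standalone illustration (hypothetical per_split_rng helper; plain random.Random shown, where the patch also supports numpy.random.RandomState):

import random

def per_split_rng(seed, split):
    # Every split starts from the same seed; burning `split` values moves
    # each partition to a different point in the same stream.
    rng = random.Random(seed)
    for _ in range(split):
        rng.randint(0, 2 ** 31 - 1)  # warm-up draw, value thrown away
    return rng

# Splits 0 and 1 now yield different, individually reproducible, values.
print(per_split_rng(42, 0).random())
print(per_split_rng(42, 1).random())

Discarding draws is a cheap way to de-correlate partitions, though consecutive splits still share a heavily overlapping stream; that trade-off is inherited from the original code, not introduced by this patch.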
@@ -59,7 +56,7 @@
         if self._use_numpy:
             return self._random.random_sample()
         else:
-            return random.uniform(0.0, 1.0)
+            return self._random.uniform(0.0, 1.0)

     def getPoissonSample(self, split, mean):
         if not self._rand_initialized or split != self._split:
@@ -73,26 +70,26 @@
             num_arrivals = 1
             cur_time = 0.0

-            cur_time += random.expovariate(mean)
+            cur_time += self._random.expovariate(mean)

             if cur_time > 1.0:
                 return 0

             while(cur_time <= 1.0):
-                cur_time += random.expovariate(mean)
+                cur_time += self._random.expovariate(mean)
                 num_arrivals += 1

             return (num_arrivals - 1)

     def shuffle(self, vals):
-        if self._random == None or split != self._split:
+        if self._random == None:
             self.initRandomGenerator(0)  # this should only ever be called on the master so
                                          # the split does not matter

         if self._use_numpy:
             self._random.shuffle(vals)
         else:
-            random.shuffle(vals, self._random)
+            self._random.shuffle(vals, self._random.random)

     def func(self, split, iterator):
         if self._withReplacement:
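Finally, getPoissonSample keeps its counting scheme but now draws from the per-sampler generator: it counts how many exponential inter-arrival gaps of rate `mean` fit inside a unit interval, which is by construction a Poisson(mean) draw (the numpy code path, not shown in this hunk, can return a Poisson variate directly). A self-contained sketch of that scheme, with a hypothetical poisson_sample name:

import random

def poisson_sample(rng, mean):
    # Arrivals of a rate-`mean` Poisson process are separated by exponential
    # gaps; the number of arrivals landing inside (0, 1] is Poisson(mean).
    num_arrivals = 0
    cur_time = rng.expovariate(mean)
    while cur_time <= 1.0:
        num_arrivals += 1
        cur_time += rng.expovariate(mean)
    return num_arrivals

rng = random.Random(123)
draws = [poisson_sample(rng, 2.0) for _ in range(10000)]
print(sum(draws) / float(len(draws)))  # sample mean, expected near 2.0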