aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rddsampler.py
diff options
context:
space:
mode:
authorArun Ramakrishnan <smartnut007@gmail.com>2014-04-24 17:27:16 -0700
committerMatei Zaharia <matei@databricks.com>2014-04-24 17:27:16 -0700
commit35e3d199f04fba3230625002a458d43b9578b2e8 (patch)
tree7e301a1585e3dc45cd1a42b8ce567b0aada57b4f /python/pyspark/rddsampler.py
parentf99af8529b6969986f0c3e03f6ff9b7bb9d53ece (diff)
downloadspark-35e3d199f04fba3230625002a458d43b9578b2e8.tar.gz
spark-35e3d199f04fba3230625002a458d43b9578b2e8.tar.bz2
spark-35e3d199f04fba3230625002a458d43b9578b2e8.zip
SPARK-1438 RDD.sample() make seed param optional
copying form previous pull request https://github.com/apache/spark/pull/462 Its probably better to let the underlying language implementation take care of the default . This was easier to do with python as the default value for seed in random and numpy random is None. In Scala/Java side it might mean propagating an Option or null(oh no!) down the chain until where the Random is constructed. But, looks like the convention in some other methods was to use System.nanoTime. So, followed that convention. Conflict with overloaded method in sql.SchemaRDD.sample which also defines default params. sample(fraction, withReplacement=false, seed=math.random) Scala does not allow more than one overloaded to have default params. I believe the author intended to override the RDD.sample method and not overload it. So, changed it. If backward compatible is important, 3 new method can be introduced (without default params) like this sample(fraction) sample(fraction, withReplacement) sample(fraction, withReplacement, seed) Added some tests for the scala RDD takeSample method. Author: Arun Ramakrishnan <smartnut007@gmail.com> This patch had conflicts when merged, resolved by Committer: Matei Zaharia <matei@databricks.com> Closes #477 from smartnut007/master and squashes the following commits: 07bb06e [Arun Ramakrishnan] SPARK-1438 fixing more space formatting issues b9ebfe2 [Arun Ramakrishnan] SPARK-1438 removing redundant import of random in python rddsampler 8d05b1a [Arun Ramakrishnan] SPARK-1438 RDD . Replace System.nanoTime with a Random generated number. python: use a separate instance of Random instead of seeding language api global Random instance. 69619c6 [Arun Ramakrishnan] SPARK-1438 fix spacing issue 0c247db [Arun Ramakrishnan] SPARK-1438 RDD language apis to support optional seed in RDD methods sample/takeSample
Diffstat (limited to 'python/pyspark/rddsampler.py')
-rw-r--r--python/pyspark/rddsampler.py31
1 files changed, 14 insertions, 17 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index aca2ef3b51..845a267e31 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,7 +19,7 @@ import sys
import random
class RDDSampler(object):
- def __init__(self, withReplacement, fraction, seed):
+ def __init__(self, withReplacement, fraction, seed=None):
try:
import numpy
self._use_numpy = True
@@ -27,7 +27,7 @@ class RDDSampler(object):
print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
self._use_numpy = False
- self._seed = seed
+ self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
self._fraction = fraction
self._random = None
@@ -38,17 +38,14 @@ class RDDSampler(object):
if self._use_numpy:
import numpy
self._random = numpy.random.RandomState(self._seed)
- for _ in range(0, split):
- # discard the next few values in the sequence to have a
- # different seed for the different splits
- self._random.randint(sys.maxint)
else:
- import random
- random.seed(self._seed)
- for _ in range(0, split):
- # discard the next few values in the sequence to have a
- # different seed for the different splits
- random.randint(0, sys.maxint)
+ self._random = random.Random(self._seed)
+
+ for _ in range(0, split):
+ # discard the next few values in the sequence to have a
+ # different seed for the different splits
+ self._random.randint(0, sys.maxint)
+
self._split = split
self._rand_initialized = True
@@ -59,7 +56,7 @@ class RDDSampler(object):
if self._use_numpy:
return self._random.random_sample()
else:
- return random.uniform(0.0, 1.0)
+ return self._random.uniform(0.0, 1.0)
def getPoissonSample(self, split, mean):
if not self._rand_initialized or split != self._split:
@@ -73,26 +70,26 @@ class RDDSampler(object):
num_arrivals = 1
cur_time = 0.0
- cur_time += random.expovariate(mean)
+ cur_time += self._random.expovariate(mean)
if cur_time > 1.0:
return 0
while(cur_time <= 1.0):
- cur_time += random.expovariate(mean)
+ cur_time += self._random.expovariate(mean)
num_arrivals += 1
return (num_arrivals - 1)
def shuffle(self, vals):
- if self._random == None or split != self._split:
+ if self._random == None:
self.initRandomGenerator(0) # this should only ever called on the master so
# the split does not matter
if self._use_numpy:
self._random.shuffle(vals)
else:
- random.shuffle(vals, self._random)
+ self._random.shuffle(vals, self._random.random)
def func(self, split, iterator):
if self._withReplacement: