aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorfreeman <the.freeman.lab@gmail.com>2014-10-22 09:33:12 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-22 09:33:12 -0700
commit97cf19f64e924569892e0a0417de19329855f4af (patch)
tree3252c2f54d296d8e79f97b2094a81cb5cf9cf0be /python
parentf05e09b4c95d799bdda3c3ff7fb76a4cd656415d (diff)
downloadspark-97cf19f64e924569892e0a0417de19329855f4af.tar.gz
spark-97cf19f64e924569892e0a0417de19329855f4af.tar.bz2
spark-97cf19f64e924569892e0a0417de19329855f4af.zip
Fix for sampling error in NumPy v1.9 [SPARK-3995][PYSPARK]
Change maximum value for default seed during RDD sampling so that it is strictly less than 2 ** 32. This prevents a bug in the most recent version of NumPy, which cannot accept random seeds above this bound. Adds an extra test that uses the default seed (instead of setting it manually, as in the docstrings). mengxr Author: freeman <the.freeman.lab@gmail.com> Closes #2889 from freeman-lab/pyspark-sampling and squashes the following commits: dc385ef [freeman] Change maximum value for default seed
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rddsampler.py4
-rw-r--r--python/pyspark/tests.py6
2 files changed, 8 insertions, 2 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 55e247da0e..528a181e89 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -31,7 +31,7 @@ class RDDSamplerBase(object):
"Falling back to default random generator for sampling.")
self._use_numpy = False
- self._seed = seed if seed is not None else random.randint(0, sys.maxint)
+ self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1)
self._withReplacement = withReplacement
self._random = None
self._split = None
@@ -47,7 +47,7 @@ class RDDSamplerBase(object):
for _ in range(0, split):
# discard the next few values in the sequence to have a
# different seed for the different splits
- self._random.randint(0, sys.maxint)
+ self._random.randint(0, 2 ** 32 - 1)
self._split = split
self._rand_initialized = True
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f5ccf31abb..1a8e4150e6 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -433,6 +433,12 @@ class RDDTests(ReusedPySparkTestCase):
os.unlink(tempFile.name)
self.assertRaises(Exception, lambda: filtered_data.count())
+ def test_sampling_default_seed(self):
+ # Test for SPARK-3995 (default seed setting)
+ data = self.sc.parallelize(range(1000), 1)
+ subset = data.takeSample(False, 10)
+ self.assertEqual(len(subset), 10)
+
def testAggregateByKey(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)