author     Xiangrui Meng <meng@databricks.com>    2014-11-05 10:30:10 -0800
committer  Xiangrui Meng <meng@databricks.com>    2014-11-05 10:30:10 -0800
commit     44751af9f8ec6a2b6ca49e5aee3e924c61afd3f7 (patch)
tree       4c687b5b05ee69ab7b78d1a8257e94fc57515272 /python
parent     1b282cdfda13e057b9cd85e1d71847d366fe7fcb (diff)
[branch-1.1][SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample
Port #3010 to branch-1.1.

Author: Xiangrui Meng <meng@databricks.com>

Closes #3104 from mengxr/SPARK-4148-1.1 and squashes the following commits:

684c002 [Xiangrui Meng] apply SPARK-4148 to branch-1.1
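A standalone sketch (not part of the patch; init_partition_rng is a hypothetical helper) of the seeding scheme this commit adopts: each partition derives its generator from the user seed XOR'd with the partition index, then burns a few draws because nearby seeds start out correlated.

import random

def init_partition_rng(seed, split):
    # one generator per partition: XOR the user seed with the split index,
    # then discard a few values to decorrelate the nearby seeds
    rng = random.Random(seed ^ split)
    for _ in range(10):
        rng.randint(0, 1)
    return rng

# the same user seed now yields a distinct stream in every partition
r0 = init_partition_rng(42, 0)
r1 = init_partition_rng(42, 1)
print([r0.random() for _ in range(3)])
print([r1.random() for _ in range(3)])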
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py         |  3 ---
-rw-r--r--  python/pyspark/rddsampler.py  | 11 +++++------
-rw-r--r--  python/pyspark/tests.py       | 15 +++++++++++++++
3 files changed, 20 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2b47b6c18e..3f81550bbb 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -417,9 +417,6 @@ class RDD(object):
         """
         Return a sampled subset of this RDD (relies on numpy and falls back
         on default random generator if numpy is unavailable).
-
-        >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
-        [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
         """
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
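The doctest above is removed rather than updated because it pinned an exact sample, which silently breaks whenever the seeding scheme changes, as it does in this very patch. A sketch of a seed-stability check that survives such changes (assumes an active SparkContext bound to sc, e.g. in a PySpark shell):

rdd = sc.parallelize(range(0, 100))
first = rdd.sample(False, 0.1, 2).collect()
second = rdd.sample(False, 0.1, 2).collect()
print(first == second)  # True: a fixed seed yields a reproducible sample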
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 55e247da0e..a6e81067cf 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -40,14 +40,13 @@ class RDDSamplerBase(object):
     def initRandomGenerator(self, split):
         if self._use_numpy:
             import numpy
-            self._random = numpy.random.RandomState(self._seed)
+            self._random = numpy.random.RandomState(self._seed ^ split)
         else:
-            self._random = random.Random(self._seed)
+            self._random = random.Random(self._seed ^ split)
 
-        for _ in range(0, split):
-            # discard the next few values in the sequence to have a
-            # different seed for the different splits
-            self._random.randint(0, sys.maxint)
+        # mixing because the initial seeds are close to each other
+        for _ in xrange(10):
+            self._random.randint(0, 1)
 
         self._split = split
         self._rand_initialized = True
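To see why the replaced discard loop was not enough, here is a standalone illustration (not from the commit; old_style_rng mimics the removed logic): seeding every partition identically and skipping split draws leaves the partitions replaying one shifted stream, so their sampling decisions overlap almost entirely.

import random

MAXINT = 2 ** 31 - 1  # stand-in for sys.maxint in the removed code

def old_style_rng(seed, split):
    # pre-patch scheme: identical seed everywhere, then skip `split` draws
    rng = random.Random(seed)
    for _ in range(split):
        rng.randint(0, MAXINT)
    return rng

r0 = old_style_rng(7, 0)
r1 = old_style_rng(7, 1)
s0 = [r0.randint(0, MAXINT) for _ in range(6)]
s1 = [r1.randint(0, MAXINT) for _ in range(6)]
print(s0[1:] == s1[:-1])  # True: partition 1's stream is partition 0's
                          # shifted by one draw, so the samples overlap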
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 8f0a351b6b..5cea1b03ea 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -470,6 +470,21 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
+    def test_sample(self):
+        rdd = self.sc.parallelize(range(0, 100), 4)
+        wo = rdd.sample(False, 0.1, 2).collect()
+        wo_dup = rdd.sample(False, 0.1, 2).collect()
+        self.assertSetEqual(set(wo), set(wo_dup))
+        wr = rdd.sample(True, 0.2, 5).collect()
+        wr_dup = rdd.sample(True, 0.2, 5).collect()
+        self.assertSetEqual(set(wr), set(wr_dup))
+        wo_s10 = rdd.sample(False, 0.3, 10).collect()
+        wo_s20 = rdd.sample(False, 0.3, 20).collect()
+        self.assertNotEqual(set(wo_s10), set(wo_s20))
+        wr_s11 = rdd.sample(True, 0.4, 11).collect()
+        wr_s21 = rdd.sample(True, 0.4, 21).collect()
+        self.assertNotEqual(set(wr_s11), set(wr_s21))
+
 
 class TestSQL(PySparkTestCase):
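The properties test_sample asserts can also be checked interactively; a hedged sketch, assuming the PySpark shell's sc:

rdd = sc.parallelize(range(0, 100), 4)
# same seed twice: identical sample
print(rdd.sample(False, 0.1, 2).collect() == rdd.sample(False, 0.1, 2).collect())
# different seeds: different samples (with overwhelming probability)
print(set(rdd.sample(False, 0.3, 10).collect()) == set(rdd.sample(False, 0.3, 20).collect()))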