diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-11-03 12:24:24 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-03 12:24:47 -0800 |
commit | a68321400c1068449698d03cebd0fbf648627133 (patch) | |
tree | 58a5b8ef2aa5a8e30ee89bbfee8e207ac463e430 /python/pyspark/tests.py | |
parent | 76386e1a23c55a58c0aeea67820aab2bac71b24b (diff) | |
download | spark-a68321400c1068449698d03cebd0fbf648627133.tar.gz spark-a68321400c1068449698d03cebd0fbf648627133.tar.bz2 spark-a68321400c1068449698d03cebd0fbf648627133.zip |
[SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample
The current way of seed distribution makes the random sequences from partition i and i+1 offset by 1.
~~~
In [14]: import random
In [15]: r1 = random.Random(10)
In [16]: r1.randint(0, 1)
Out[16]: 1
In [17]: r1.random()
Out[17]: 0.4288890546751146
In [18]: r1.random()
Out[18]: 0.5780913011344704
In [19]: r2 = random.Random(10)
In [20]: r2.randint(0, 1)
Out[20]: 1
In [21]: r2.randint(0, 1)
Out[21]: 0
In [22]: r2.random()
Out[22]: 0.5780913011344704
~~~
Note: The new tests are not for this bug fix.
Author: Xiangrui Meng <meng@databricks.com>
Closes #3010 from mengxr/SPARK-4148 and squashes the following commits:
869ae4b [Xiangrui Meng] move tests to tests.py
c1bacd9 [Xiangrui Meng] fix seed distribution and add some tests for rdd.sample
(cherry picked from commit 3cca1962207745814b9d83e791713c91b659c36c)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- | python/pyspark/tests.py | 15 |
1 file changed, 15 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 37a128907b..253a471849 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -648,6 +648,21 @@ class RDDTests(ReusedPySparkTestCase): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + class ProfilerTests(PySparkTestCase): |