about summary refs log tree commit diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-03 12:24:24 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-03 12:24:24 -0800
commit3cca1962207745814b9d83e791713c91b659c36c (patch)
tree997330407c25dabe70d5792421ce6e831976300b /python/pyspark/tests.py
parent2aca97c7cfdefea8b6f9dbb88951e9acdfd606d9 (diff)
downloadspark-3cca1962207745814b9d83e791713c91b659c36c.tar.gz
spark-3cca1962207745814b9d83e791713c91b659c36c.tar.bz2
spark-3cca1962207745814b9d83e791713c91b659c36c.zip
[SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample
The current way of seed distribution makes the random sequences from partition i and i+1 offset by 1. ~~~ In [14]: import random In [15]: r1 = random.Random(10) In [16]: r1.randint(0, 1) Out[16]: 1 In [17]: r1.random() Out[17]: 0.4288890546751146 In [18]: r1.random() Out[18]: 0.5780913011344704 In [19]: r2 = random.Random(10) In [20]: r2.randint(0, 1) Out[20]: 1 In [21]: r2.randint(0, 1) Out[21]: 0 In [22]: r2.random() Out[22]: 0.5780913011344704 ~~~ Note: The new tests are not for this bug fix. Author: Xiangrui Meng <meng@databricks.com> Closes #3010 from mengxr/SPARK-4148 and squashes the following commits: 869ae4b [Xiangrui Meng] move tests tests.py c1bacd9 [Xiangrui Meng] fix seed distribution and add some tests for rdd.sample
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py  15
1 file changed, 15 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 37a128907b..253a471849 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -648,6 +648,21 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEquals(result.getNumPartitions(), 5)
self.assertEquals(result.count(), 3)
+ def test_sample(self):
+ rdd = self.sc.parallelize(range(0, 100), 4)
+ wo = rdd.sample(False, 0.1, 2).collect()
+ wo_dup = rdd.sample(False, 0.1, 2).collect()
+ self.assertSetEqual(set(wo), set(wo_dup))
+ wr = rdd.sample(True, 0.2, 5).collect()
+ wr_dup = rdd.sample(True, 0.2, 5).collect()
+ self.assertSetEqual(set(wr), set(wr_dup))
+ wo_s10 = rdd.sample(False, 0.3, 10).collect()
+ wo_s20 = rdd.sample(False, 0.3, 20).collect()
+ self.assertNotEqual(set(wo_s10), set(wo_s20))
+ wr_s11 = rdd.sample(True, 0.4, 11).collect()
+ wr_s21 = rdd.sample(True, 0.4, 21).collect()
+ self.assertNotEqual(set(wr_s11), set(wr_s21))
+
class ProfilerTests(PySparkTestCase):