aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 8f0a351b6b..5cea1b03ea 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -470,6 +470,21 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
self.assertRaises(TypeError, lambda: rdd.histogram(2))
+ def test_sample(self):
+ rdd = self.sc.parallelize(range(0, 100), 4)
+ wo = rdd.sample(False, 0.1, 2).collect()
+ wo_dup = rdd.sample(False, 0.1, 2).collect()
+ self.assertSetEqual(set(wo), set(wo_dup))
+ wr = rdd.sample(True, 0.2, 5).collect()
+ wr_dup = rdd.sample(True, 0.2, 5).collect()
+ self.assertSetEqual(set(wr), set(wr_dup))
+ wo_s10 = rdd.sample(False, 0.3, 10).collect()
+ wo_s20 = rdd.sample(False, 0.3, 20).collect()
+ self.assertNotEqual(set(wo_s10), set(wo_s20))
+ wr_s11 = rdd.sample(True, 0.4, 11).collect()
+ wr_s21 = rdd.sample(True, 0.4, 21).collect()
+ self.assertNotEqual(set(wr_s11), set(wr_s21))
+
class TestSQL(PySparkTestCase):