diff options
author | Davies Liu <davies@databricks.com> | 2014-11-18 16:37:35 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-18 16:37:35 -0800 |
commit | 7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf (patch) | |
tree | 1c79469770ebd75dc3eea0fd165bc4beede27ab8 /python/pyspark/rddsampler.py | |
parent | bb46046154a438df4db30a0e1fd557bd3399ee7b (diff) | |
download | spark-7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf.tar.gz spark-7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf.tar.bz2 spark-7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf.zip |
[SPARK-4327] [PySpark] Python API for RDD.randomSplit()
```
pyspark.RDD.randomSplit(self, weights, seed=None)
Randomly splits this RDD with the provided weights.
:param weights: weights for splits, will be normalized if they don't sum to 1
:param seed: random seed
:return: split RDDs in an list
>>> rdd = sc.parallelize(range(10), 1)
>>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
>>> rdd1.collect()
[3, 6]
>>> rdd2.collect()
[0, 5, 7]
>>> rdd3.collect()
[1, 2, 4, 8, 9]
```
Author: Davies Liu <davies@databricks.com>
Closes #3193 from davies/randomSplit and squashes the following commits:
78bf997 [Davies Liu] fix tests, do not use numpy in randomSplit, no performance gain
f5fdf63 [Davies Liu] fix bug with int in weights
4dfa2cd [Davies Liu] refactor
f866bcf [Davies Liu] remove unneeded change
c7a2007 [Davies Liu] switch to python implementation
95a48ac [Davies Liu] Merge branch 'master' of github.com:apache/spark into randomSplit
0d9b256 [Davies Liu] refactor
1715ee3 [Davies Liu] address comments
41fce54 [Davies Liu] randomSplit()
Diffstat (limited to 'python/pyspark/rddsampler.py')
-rw-r--r-- | python/pyspark/rddsampler.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index f5c3cfd259..558dcfd12d 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -115,6 +115,20 @@ class RDDSampler(RDDSamplerBase): yield obj +class RDDRangeSampler(RDDSamplerBase): + + def __init__(self, lowerBound, upperBound, seed=None): + RDDSamplerBase.__init__(self, False, seed) + self._use_numpy = False # no performance gain from numpy + self._lowerBound = lowerBound + self._upperBound = upperBound + + def func(self, split, iterator): + for obj in iterator: + if self._lowerBound <= self.getUniformSample(split) < self._upperBound: + yield obj + + class RDDStratifiedSampler(RDDSamplerBase): def __init__(self, withReplacement, fractions, seed=None): |