aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rddsampler.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-18 16:37:35 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-18 16:37:35 -0800
commit7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf (patch)
tree1c79469770ebd75dc3eea0fd165bc4beede27ab8 /python/pyspark/rddsampler.py
parentbb46046154a438df4db30a0e1fd557bd3399ee7b (diff)
downloadspark-7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf.tar.gz
spark-7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf.tar.bz2
spark-7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf.zip
[SPARK-4327] [PySpark] Python API for RDD.randomSplit()
``` pyspark.RDD.randomSplit(self, weights, seed=None) Randomly splits this RDD with the provided weights. :param weights: weights for splits, will be normalized if they don't sum to 1 :param seed: random seed :return: split RDDs in an list >>> rdd = sc.parallelize(range(10), 1) >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11) >>> rdd1.collect() [3, 6] >>> rdd2.collect() [0, 5, 7] >>> rdd3.collect() [1, 2, 4, 8, 9] ``` Author: Davies Liu <davies@databricks.com> Closes #3193 from davies/randomSplit and squashes the following commits: 78bf997 [Davies Liu] fix tests, do not use numpy in randomSplit, no performance gain f5fdf63 [Davies Liu] fix bug with int in weights 4dfa2cd [Davies Liu] refactor f866bcf [Davies Liu] remove unneeded change c7a2007 [Davies Liu] switch to python implementation 95a48ac [Davies Liu] Merge branch 'master' of github.com:apache/spark into randomSplit 0d9b256 [Davies Liu] refactor 1715ee3 [Davies Liu] address comments 41fce54 [Davies Liu] randomSplit()
Diffstat (limited to 'python/pyspark/rddsampler.py')
-rw-r--r--python/pyspark/rddsampler.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index f5c3cfd259..558dcfd12d 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -115,6 +115,20 @@ class RDDSampler(RDDSamplerBase):
yield obj
+class RDDRangeSampler(RDDSamplerBase):
+
+ def __init__(self, lowerBound, upperBound, seed=None):
+ RDDSamplerBase.__init__(self, False, seed)
+ self._use_numpy = False # no performance gain from numpy
+ self._lowerBound = lowerBound
+ self._upperBound = upperBound
+
+ def func(self, split, iterator):
+ for obj in iterator:
+ if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
+ yield obj
+
+
class RDDStratifiedSampler(RDDSamplerBase):
def __init__(self, withReplacement, fractions, seed=None):