diff options
author | Doris Xin <doris.s.xin@gmail.com> | 2014-07-24 23:42:08 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-07-24 23:42:08 -0700 |
commit | 2f75a4a30e1a3fdf384475b9660c6c43f093f68c (patch) | |
tree | eb50e720cae6842bcb242d030adb27fba92c3f62 /python/pyspark/rdd.py | |
parent | 14174abd421318e71c16edd24224fd5094bdfed4 (diff) | |
download | spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.gz spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.bz2 spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.zip |
[SPARK-2656] Python version of stratified sampling
Exact sample size is not supported for now.
Author: Doris Xin <doris.s.xin@gmail.com>
Closes #1554 from dorx/pystratified and squashes the following commits:
4ba927a [Doris Xin] use rel diff (+- 50%) instead of abs diff (+- 50)
bdc3f8b [Doris Xin] updated unit to check sample holistically
7713c7b [Doris Xin] Python version of stratified sampling
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r-- | python/pyspark/rdd.py | 25 |
1 file changed, 23 insertions, 2 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7ad6108261..113a082e16 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler
+from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -411,7 +411,7 @@ class RDD(object):
         >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
         [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
         """
-        assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
+        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
     # this is ported from scala/spark/RDD.scala
@@ -1456,6 +1456,27 @@ class RDD(object):
         """
         return python_cogroup((self, other), numPartitions)
 
+    def sampleByKey(self, withReplacement, fractions, seed=None):
+        """
+        Return a subset of this RDD sampled by key (via stratified sampling).
+        Create a sample of this RDD using variable sampling rates for
+        different keys as specified by fractions, a key to sampling rate map.
+
+        >>> fractions = {"a": 0.2, "b": 0.1}
+        >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000)))
+        >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect())
+        >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150
+        True
+        >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0
+        True
+        >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0
+        True
+        """
+        for fraction in fractions.values():
+            assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        return self.mapPartitionsWithIndex( \
+            RDDStratifiedSampler(withReplacement, fractions, seed).func, True)
+
     def subtractByKey(self, other, numPartitions=None):
         """
         Return each (key, value) pair in C{self} that has no pair with matching