aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
authorDoris Xin <doris.s.xin@gmail.com>2014-07-24 23:42:08 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-24 23:42:08 -0700
commit2f75a4a30e1a3fdf384475b9660c6c43f093f68c (patch)
treeeb50e720cae6842bcb242d030adb27fba92c3f62 /python/pyspark/rdd.py
parent14174abd421318e71c16edd24224fd5094bdfed4 (diff)
downloadspark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.gz
spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.bz2
spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.zip
[SPARK-2656] Python version of stratified sampling
exact sample size not supported for now. Author: Doris Xin <doris.s.xin@gmail.com> Closes #1554 from dorx/pystratified and squashes the following commits: 4ba927a [Doris Xin] use rel diff (+- 50%) instead of abs diff (+- 50) bdc3f8b [Doris Xin] updated unit to check sample holistically 7713c7b [Doris Xin] Python version of stratified sampling
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--python/pyspark/rdd.py25
1 file changed, 23 insertions, 2 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7ad6108261..113a082e16 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler
+from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -411,7 +411,7 @@ class RDD(object):
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
"""
- assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
# this is ported from scala/spark/RDD.scala
@@ -1456,6 +1456,27 @@ class RDD(object):
"""
return python_cogroup((self, other), numPartitions)
def sampleByKey(self, withReplacement, fractions, seed=None):
    """
    Return a subset of this RDD sampled by key (via stratified sampling).

    Create a sample of this RDD using variable sampling rates for
    different keys as specified by fractions, a key to sampling rate map.

    withReplacement: whether an element may be sampled more than once.
    fractions: dict mapping each key to its sampling rate; every rate
        must be non-negative (enforced below; validation of keys not in
        the map is left to RDDStratifiedSampler -- TODO confirm).
    seed: optional seed for the random generator, for reproducibility.

    >>> fractions = {"a": 0.2, "b": 0.1}
    >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000)))
    >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect())
    >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150
    True
    >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0
    True
    >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0
    True
    """
    # Fail fast on invalid rates before launching any distributed work.
    for fraction in fractions.values():
        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
    # Implicit continuation inside the call parentheses; the original's
    # trailing backslash was redundant (PEP 8).
    return self.mapPartitionsWithIndex(
        RDDStratifiedSampler(withReplacement, fractions, seed).func, True)
+
def subtractByKey(self, other, numPartitions=None):
"""
Return each (key, value) pair in C{self} that has no pair with matching