aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDoris Xin <doris.s.xin@gmail.com>2014-07-24 23:42:08 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-24 23:42:08 -0700
commit2f75a4a30e1a3fdf384475b9660c6c43f093f68c (patch)
treeeb50e720cae6842bcb242d030adb27fba92c3f62 /python
parent14174abd421318e71c16edd24224fd5094bdfed4 (diff)
downloadspark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.gz
spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.tar.bz2
spark-2f75a4a30e1a3fdf384475b9660c6c43f093f68c.zip
[SPARK-2656] Python version of stratified sampling
exact sample size not supported for now. Author: Doris Xin <doris.s.xin@gmail.com> Closes #1554 from dorx/pystratified and squashes the following commits: 4ba927a [Doris Xin] use rel diff (+- 50%) instead of abs diff (+- 50) bdc3f8b [Doris Xin] updated unit test to check sample holistically 7713c7b [Doris Xin] Python version of stratified sampling
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py25
-rw-r--r--python/pyspark/rddsampler.py30
2 files changed, 50 insertions, 5 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7ad6108261..113a082e16 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler
+from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -411,7 +411,7 @@ class RDD(object):
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
"""
- assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
# this is ported from scala/spark/RDD.scala
@@ -1456,6 +1456,27 @@ class RDD(object):
"""
return python_cogroup((self, other), numPartitions)
+ def sampleByKey(self, withReplacement, fractions, seed=None):
+ """
+ Return a subset of this RDD sampled by key (via stratified sampling).
+ Create a sample of this RDD using variable sampling rates for
+ different keys as specified by fractions, a key to sampling rate map.
+
+ >>> fractions = {"a": 0.2, "b": 0.1}
+ >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000)))
+ >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect())
+ >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150
+ True
+ >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0
+ True
+ >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0
+ True
+ """
+ for fraction in fractions.values():
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ return self.mapPartitionsWithIndex( \
+ RDDStratifiedSampler(withReplacement, fractions, seed).func, True)
+
def subtractByKey(self, other, numPartitions=None):
"""
Return each (key, value) pair in C{self} that has no pair with matching
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 7ff1c316c7..2df000fdb0 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,8 +19,8 @@ import sys
import random
-class RDDSampler(object):
- def __init__(self, withReplacement, fraction, seed=None):
+class RDDSamplerBase(object):
+ def __init__(self, withReplacement, seed=None):
try:
import numpy
self._use_numpy = True
@@ -32,7 +32,6 @@ class RDDSampler(object):
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
- self._fraction = fraction
self._random = None
self._split = None
self._rand_initialized = False
@@ -94,6 +93,12 @@ class RDDSampler(object):
else:
self._random.shuffle(vals, self._random.random)
+
+class RDDSampler(RDDSamplerBase):
+ def __init__(self, withReplacement, fraction, seed=None):
+ RDDSamplerBase.__init__(self, withReplacement, seed)
+ self._fraction = fraction
+
def func(self, split, iterator):
if self._withReplacement:
for obj in iterator:
@@ -107,3 +112,22 @@ class RDDSampler(object):
for obj in iterator:
if self.getUniformSample(split) <= self._fraction:
yield obj
+
+class RDDStratifiedSampler(RDDSamplerBase):
+ def __init__(self, withReplacement, fractions, seed=None):
+ RDDSamplerBase.__init__(self, withReplacement, seed)
+ self._fractions = fractions
+
+ def func(self, split, iterator):
+ if self._withReplacement:
+ for key, val in iterator:
+ # For large datasets, the expected number of occurrences of each element in
+ # a sample with replacement is Poisson(frac). We use that to get a count for
+ # each element.
+ count = self.getPoissonSample(split, mean=self._fractions[key])
+ for _ in range(0, count):
+ yield key, val
+ else:
+ for key, val in iterator:
+ if self.getUniformSample(split) <= self._fractions[key]:
+ yield key, val