author     Andre Schumacher <schumach@icsi.berkeley.edu>  2013-08-23 11:16:44 -0700
committer  Andre Schumacher <schumach@icsi.berkeley.edu>  2013-08-28 16:46:13 -0700
commit     a511c5379ee156f08624e380b240af7d961a60f7
tree       a25b0a5c7cdb2677e1b910625beccfa5e134ddd0 /python/pyspark/rdd.py
parent     a9db1b7b6eb030feb7beee017d2eca402b73c67c
RDD sample() and takeSample() prototypes for PySpark
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  |  62
1 file changed, 55 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1e9b3bb5c0..8394fe6a31 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -21,6 +21,7 @@ from collections import defaultdict
from itertools import chain, ifilter, imap, product
import operator
import os
+import sys
import shlex
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
@@ -32,6 +33,7 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
+from pyspark.rddsampler import RDDSampler
from py4j.java_collections import ListConverter, MapConverter
@@ -165,14 +167,60 @@ class RDD(object):
                    .reduceByKey(lambda x, _: x) \
                    .map(lambda (x, _): x)
-    # TODO: sampling needs to be re-implemented due to Batch
-    #def sample(self, withReplacement, fraction, seed):
-    #    jrdd = self._jrdd.sample(withReplacement, fraction, seed)
-    #    return RDD(jrdd, self.ctx)
+    def sample(self, withReplacement, fraction, seed):
+        """
+        Return a sampled subset of this RDD (relies on numpy and falls back
+        on default random generator if numpy is unavailable).
+
+        >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
+        [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
+        """
+        return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True)
+
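The sampling logic itself lives in the new pyspark.rddsampler module, which is
not shown in this diff. A minimal hypothetical stand-in for the
func(split, iterator) shape that mapPartitionsWithSplit expects might look like
the sketch below; the real RDDSampler additionally handles numpy and sampling
with replacement.

import random

class SimpleSampler(object):
    """Hypothetical stand-in for pyspark.rddsampler.RDDSampler; illustrates
    only the interface used by sample() above."""

    def __init__(self, withReplacement, fraction, seed):
        self._withReplacement = withReplacement
        self._fraction = fraction
        self._seed = seed

    def func(self, split, iterator):
        # Seed one generator per partition so partitions sample
        # independently but reproducibly for a fixed (seed, split) pair.
        rnd = random.Random(self._seed + split)
        for item in iterator:
            # Bernoulli sampling without replacement: keep each element
            # with probability `fraction`. Sampling with replacement would
            # draw a per-element occurrence count instead (omitted here).
            if rnd.random() < self._fraction:
                yield item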
-    #def takeSample(self, withReplacement, num, seed):
-    #    vals = self._jrdd.takeSample(withReplacement, num, seed)
-    #    return [load_pickle(bytes(x)) for x in vals]
+    # this is ported from scala/spark/RDD.scala
+    def takeSample(self, withReplacement, num, seed):
+        """
+        Return a fixed-size sampled subset of this RDD (currently requires numpy).
+
+        >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
+        [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
+        """
+        fraction = 0.0
+        total = 0
+        multiplier = 3.0
+        initialCount = self.count()
+        maxSelected = 0
+
+        if (num < 0):
+            raise ValueError
+
+        if initialCount > sys.maxint - 1:
+            maxSelected = sys.maxint - 1
+        else:
+            maxSelected = initialCount
+
+        if num > initialCount and not withReplacement:
+            total = maxSelected
+            fraction = multiplier * (maxSelected + 1) / initialCount
+        else:
+            fraction = multiplier * (num + 1) / initialCount
+            total = num
+
+        samples = self.sample(withReplacement, fraction, seed).collect()
+
+        # If the first sample didn't turn out large enough, keep trying to take samples;
+        # this shouldn't happen often because we use a big multiplier for their initial size.
+        # See: scala/spark/RDD.scala
+        while len(samples) < total:
+            if seed > sys.maxint - 2:
+                seed = -1
+            seed += 1
+            samples = self.sample(withReplacement, fraction, seed).collect()
+
+        sampler = RDDSampler(withReplacement, fraction, seed+1)
+        sampler.shuffle(samples)
+        return samples[0:total]
     def union(self, other):
         """