From a511c5379ee156f08624e380b240af7d961a60f7 Mon Sep 17 00:00:00 2001 From: Andre Schumacher Date: Fri, 23 Aug 2013 11:16:44 -0700 Subject: RDD sample() and takeSample() prototypes for PySpark --- python/pyspark/rdd.py | 62 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 7 deletions(-) (limited to 'python/pyspark/rdd.py') diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1e9b3bb5c0..8394fe6a31 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -21,6 +21,7 @@ from collections import defaultdict from itertools import chain, ifilter, imap, product import operator import os +import sys import shlex from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile @@ -32,6 +33,7 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter +from pyspark.rddsampler import RDDSampler from py4j.java_collections import ListConverter, MapConverter @@ -165,14 +167,60 @@ class RDD(object): .reduceByKey(lambda x, _: x) \ .map(lambda (x, _): x) - # TODO: sampling needs to be re-implemented due to Batch - #def sample(self, withReplacement, fraction, seed): - # jrdd = self._jrdd.sample(withReplacement, fraction, seed) - # return RDD(jrdd, self.ctx) + def sample(self, withReplacement, fraction, seed): + """ + Return a sampled subset of this RDD (relies on numpy and falls back + on default random generator if numpy is unavailable). + + >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP + [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] + """ + return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True) + + # this is ported from scala/spark/RDD.scala + def takeSample(self, withReplacement, num, seed): + """ + Return a fixed-size sampled subset of this RDD (currently requires numpy). + + >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP + [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] + """ - #def takeSample(self, withReplacement, num, seed): - # vals = self._jrdd.takeSample(withReplacement, num, seed) - # return [load_pickle(bytes(x)) for x in vals] + fraction = 0.0 + total = 0 + multiplier = 3.0 + initialCount = self.count() + maxSelected = 0 + + if (num < 0): + raise ValueError + + if initialCount > sys.maxint - 1: + maxSelected = sys.maxint - 1 + else: + maxSelected = initialCount + + if num > initialCount and not withReplacement: + total = maxSelected + fraction = multiplier * (maxSelected + 1) / initialCount + else: + fraction = multiplier * (num + 1) / initialCount + total = num + + samples = self.sample(withReplacement, fraction, seed).collect() + + # If the first sample didn't turn out large enough, keep trying to take samples; + # this shouldn't happen often because we use a big multiplier for their initial size. + # See: scala/spark/RDD.scala + while len(samples) < total: + if seed > sys.maxint - 2: + seed = -1 + seed += 1 + samples = self.sample(withReplacement, fraction, seed).collect() + + sampler = RDDSampler(withReplacement, fraction, seed+1) + sampler.shuffle(samples) + return samples[0:total] def union(self, other): """ -- cgit v1.2.3