From 70d9e3871f852ec9e8bfaa436bc02bc22fc62dfd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 18 Nov 2014 16:37:35 -0800 Subject: [SPARK-4327] [PySpark] Python API for RDD.randomSplit() ``` pyspark.RDD.randomSplit(self, weights, seed=None) Randomly splits this RDD with the provided weights. :param weights: weights for splits, will be normalized if they don't sum to 1 :param seed: random seed :return: split RDDs in a list >>> rdd = sc.parallelize(range(10), 1) >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11) >>> rdd1.collect() [3, 6] >>> rdd2.collect() [0, 5, 7] >>> rdd3.collect() [1, 2, 4, 8, 9] ``` Author: Davies Liu Closes #3193 from davies/randomSplit and squashes the following commits: 78bf997 [Davies Liu] fix tests, do not use numpy in randomSplit, no performance gain f5fdf63 [Davies Liu] fix bug with int in weights 4dfa2cd [Davies Liu] refactor f866bcf [Davies Liu] remove unneeded change c7a2007 [Davies Liu] switch to python implementation 95a48ac [Davies Liu] Merge branch 'master' of github.com:apache/spark into randomSplit 0d9b256 [Davies Liu] refactor 1715ee3 [Davies Liu] address comments 41fce54 [Davies Liu] randomSplit() (cherry picked from commit 7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf) Signed-off-by: Xiangrui Meng --- python/pyspark/rdd.py | 30 +++++++++++++++++++++++++++--- python/pyspark/rddsampler.py | 14 ++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 08d0474026..50535d2711 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -28,7 +28,7 @@ from threading import Thread import warnings import heapq import bisect -from random import Random +import random from math import sqrt, log, isinf, isnan from pyspark.accumulators import PStatsParam @@ -38,7 +38,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, 
python_cogroup from pyspark.statcounter import StatCounter -from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler +from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ @@ -316,6 +316,30 @@ class RDD(object): assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) + def randomSplit(self, weights, seed=None): + """ + Randomly splits this RDD with the provided weights. + + :param weights: weights for splits, will be normalized if they don't sum to 1 + :param seed: random seed + :return: split RDDs in a list + + >>> rdd = sc.parallelize(range(5), 1) + >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17) + >>> rdd1.collect() + [1, 3] + >>> rdd2.collect() + [0, 2, 4] + """ + s = float(sum(weights)) + cweights = [0.0] + for w in weights: + cweights.append(cweights[-1] + w / s) + if seed is None: + seed = random.randint(0, 2 ** 32 - 1) + return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True) + for lb, ub in zip(cweights, cweights[1:])] + # this is ported from scala/spark/RDD.scala def takeSample(self, withReplacement, num, seed=None): """ @@ -341,7 +365,7 @@ class RDD(object): if initialCount == 0: return [] - rand = Random(seed) + rand = random.Random(seed) if (not withReplacement) and num >= initialCount: # shuffle current RDD and return diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index f5c3cfd259..558dcfd12d 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -115,6 +115,20 @@ class RDDSampler(RDDSamplerBase): yield obj +class RDDRangeSampler(RDDSamplerBase): + + def __init__(self, lowerBound, upperBound, seed=None): + RDDSamplerBase.__init__(self, False, seed) + self._use_numpy = False # no 
performance gain from numpy + self._lowerBound = lowerBound + self._upperBound = upperBound + + def func(self, split, iterator): + for obj in iterator: + if self._lowerBound <= self.getUniformSample(split) < self._upperBound: + yield obj + + class RDDStratifiedSampler(RDDSamplerBase): def __init__(self, withReplacement, fractions, seed=None): -- cgit v1.2.3