Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py          30
-rw-r--r--  python/pyspark/rddsampler.py   14
2 files changed, 41 insertions, 3 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 08d0474026..50535d2711 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,7 +28,7 @@ from threading import Thread
import warnings
import heapq
import bisect
-from random import Random
+import random
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import PStatsParam
@@ -38,7 +38,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
from pyspark.join import python_join, python_left_outer_join, \
    python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -316,6 +316,30 @@ class RDD(object):
        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
        return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

+    def randomSplit(self, weights, seed=None):
+        """
+        Randomly splits this RDD with the provided weights.
+
+        :param weights: weights for splits, will be normalized if they don't sum to 1
+        :param seed: random seed
+        :return: split RDDs in a list
+
+        >>> rdd = sc.parallelize(range(5), 1)
+        >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
+        >>> rdd1.collect()
+        [1, 3]
+        >>> rdd2.collect()
+        [0, 2, 4]
+        """
+        s = float(sum(weights))
+        cweights = [0.0]
+        for w in weights:
+            cweights.append(cweights[-1] + w / s)
+        if seed is None:
+            seed = random.randint(0, 2 ** 32 - 1)
+        return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
+                for lb, ub in zip(cweights, cweights[1:])]
+
    # this is ported from scala/spark/RDD.scala
    def takeSample(self, withReplacement, num, seed=None):
        """
@@ -341,7 +365,7 @@ class RDD(object):
        if initialCount == 0:
            return []

-        rand = Random(seed)
+        rand = random.Random(seed)
        if (not withReplacement) and num >= initialCount:
            # shuffle current RDD and return
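
The takeSample hunk is mechanical fallout of the import change: with "import random" replacing "from random import Random", the class is now reached as random.Random. A private, seeded generator keeps the sampling reproducible without disturbing the module-level RNG; a quick illustration, not taken from the patch:

    import random

    rand = random.Random(42)  # private generator, independent of random.seed()
    assert rand.random() == random.Random(42).random()  # same seed, same draw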
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index f5c3cfd259..558dcfd12d 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -115,6 +115,20 @@ class RDDSampler(RDDSamplerBase):
                    yield obj


+class RDDRangeSampler(RDDSamplerBase):
+
+    def __init__(self, lowerBound, upperBound, seed=None):
+        RDDSamplerBase.__init__(self, False, seed)
+        self._use_numpy = False  # no performance gain from numpy
+        self._lowerBound = lowerBound
+        self._upperBound = upperBound
+
+    def func(self, split, iterator):
+        for obj in iterator:
+            if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
+                yield obj
+
+
class RDDStratifiedSampler(RDDSamplerBase):

    def __init__(self, withReplacement, fractions, seed=None):
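
The property that makes randomSplit work is that every RDDRangeSampler for a given split call receives the same seed: each element gets the same uniform draw in every sampler, so it falls into exactly one [lowerBound, upperBound) range, and the resulting splits are disjoint and together cover the whole RDD. A minimal local sketch of that idea, with hypothetical names and Python's random module standing in for the sampler's per-partition RNG:

    import random

    def range_split(data, bounds, seed):
        # One generator per split, all seeded identically: the i-th draw is the
        # same in every split, so each element lands in exactly one range.
        splits = []
        for lb, ub in zip(bounds, bounds[1:]):
            rng = random.Random(seed)
            splits.append([x for x in data if lb <= rng.random() < ub])
        return splits

    a, b = range_split(range(5), [0.0, 0.4, 1.0], seed=17)
    assert sorted(a + b) == list(range(5))  # disjoint and exhaustive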