author    Davies Liu <davies@databricks.com>    2014-11-18 16:37:35 -0800
committer Xiangrui Meng <meng@databricks.com>  2014-11-18 16:37:43 -0800
commit    70d9e3871f852ec9e8bfaa436bc02bc22fc62dfd (patch)
tree      189f02d8ed0cbe6959ad74997335dbe95d35ffb9 /python
parent    bf76164f1090892544983f753d4b7b16903a6135 (diff)
download  spark-70d9e3871f852ec9e8bfaa436bc02bc22fc62dfd.tar.gz
          spark-70d9e3871f852ec9e8bfaa436bc02bc22fc62dfd.tar.bz2
          spark-70d9e3871f852ec9e8bfaa436bc02bc22fc62dfd.zip
[SPARK-4327] [PySpark] Python API for RDD.randomSplit()
```
pyspark.RDD.randomSplit(self, weights, seed=None)

    Randomly splits this RDD with the provided weights.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed
    :return: split RDDs in a list

    >>> rdd = sc.parallelize(range(10), 1)
    >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
    >>> rdd1.collect()
    [3, 6]
    >>> rdd2.collect()
    [0, 5, 7]
    >>> rdd3.collect()
    [1, 2, 4, 8, 9]
```

Author: Davies Liu <davies@databricks.com>

Closes #3193 from davies/randomSplit and squashes the following commits:

78bf997 [Davies Liu] fix tests, do not use numpy in randomSplit, no performance gain
f5fdf63 [Davies Liu] fix bug with int in weights
4dfa2cd [Davies Liu] refactor
f866bcf [Davies Liu] remove unneeded change
c7a2007 [Davies Liu] switch to python implementation
95a48ac [Davies Liu] Merge branch 'master' of github.com:apache/spark into randomSplit
0d9b256 [Davies Liu] refactor
1715ee3 [Davies Liu] address comments
41fce54 [Davies Liu] randomSplit()

(cherry picked from commit 7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
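A quick usage sketch of the new API (my own illustration, not part of the patch; it assumes a live SparkContext bound to `sc`, as in the doctests above). It highlights the normalization noted in the docstring and the fact that the splits partition the parent RDD:

```python
# Hedged sketch: integer weights are normalized, so [2, 3] acts like [0.4, 0.6].
rdd = sc.parallelize(range(100), 4)            # assumes an existing SparkContext `sc`
train, test = rdd.randomSplit([2, 3], seed=17)

# One uniform draw per element, shared across splits via the common seed,
# so every element lands in exactly one split: disjoint and exhaustive.
assert sorted(train.collect() + test.collect()) == list(range(100))
```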
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py         30
-rw-r--r--  python/pyspark/rddsampler.py  14
2 files changed, 41 insertions, 3 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 08d0474026..50535d2711 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,7 +28,7 @@ from threading import Thread
 import warnings
 import heapq
 import bisect
-from random import Random
+import random
 from math import sqrt, log, isinf, isnan
 
 from pyspark.accumulators import PStatsParam
@@ -38,7 +38,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -316,6 +316,30 @@ class RDD(object):
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
+    def randomSplit(self, weights, seed=None):
+        """
+        Randomly splits this RDD with the provided weights.
+
+        :param weights: weights for splits, will be normalized if they don't sum to 1
+        :param seed: random seed
+        :return: split RDDs in a list
+
+        >>> rdd = sc.parallelize(range(5), 1)
+        >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
+        >>> rdd1.collect()
+        [1, 3]
+        >>> rdd2.collect()
+        [0, 2, 4]
+        """
+        s = float(sum(weights))
+        cweights = [0.0]
+        for w in weights:
+            cweights.append(cweights[-1] + w / s)
+        if seed is None:
+            seed = random.randint(0, 2 ** 32 - 1)
+        return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
+                for lb, ub in zip(cweights, cweights[1:])]
+
     # this is ported from scala/spark/RDD.scala
     def takeSample(self, withReplacement, num, seed=None):
         """
@@ -341,7 +365,7 @@ class RDD(object):
         if initialCount == 0:
             return []
 
-        rand = Random(seed)
+        rand = random.Random(seed)
         if (not withReplacement) and num >= initialCount:
             # shuffle current RDD and return
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index f5c3cfd259..558dcfd12d 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -115,6 +115,20 @@ class RDDSampler(RDDSamplerBase):
                 yield obj
 
 
+class RDDRangeSampler(RDDSamplerBase):
+
+    def __init__(self, lowerBound, upperBound, seed=None):
+        RDDSamplerBase.__init__(self, False, seed)
+        self._use_numpy = False  # no performance gain from numpy
+        self._lowerBound = lowerBound
+        self._upperBound = upperBound
+
+    def func(self, split, iterator):
+        for obj in iterator:
+            if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
+                yield obj
+
+
 class RDDStratifiedSampler(RDDSamplerBase):
 
     def __init__(self, withReplacement, fractions, seed=None):
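For intuition about why RDDRangeSampler produces disjoint, exhaustive splits, here is a standalone sketch of the same cumulative-weight trick in plain Python (`random_split` is a hypothetical helper for illustration, not PySpark API; the real sampler also reseeds per partition). Each pass reseeds the generator identically, so element k receives the same uniform draw in every pass and falls into exactly one half-open range:

```python
import random

def random_split(items, weights, seed=11):
    # Hypothetical stand-in for RDD.randomSplit(); mirrors the patch's logic.
    s = float(sum(weights))
    bounds = [0.0]
    for w in weights:                      # cumulative, normalized weights
        bounds.append(bounds[-1] + w / s)

    splits = []
    for lb, ub in zip(bounds, bounds[1:]):
        rng = random.Random(seed)          # identical seed for every range
        # Element k always gets the k-th draw, so it matches exactly one range.
        splits.append([x for x in items if lb <= rng.random() < ub])
    return splits

a, b, c = random_split(list(range(10)), [0.4, 0.6, 1.0])
assert sorted(a + b + c) == list(range(10))   # disjoint and exhaustive
```

Since draws lie in [0, 1) and the ranges tile [0, 1) exactly, no element is dropped or duplicated; this is why the patch can avoid numpy entirely with no performance loss, as the first squashed commit notes.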