aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
authorArun Ramakrishnan <smartnut007@gmail.com>2014-04-24 17:27:16 -0700
committerMatei Zaharia <matei@databricks.com>2014-04-24 17:27:16 -0700
commit35e3d199f04fba3230625002a458d43b9578b2e8 (patch)
tree7e301a1585e3dc45cd1a42b8ce567b0aada57b4f /python/pyspark/rdd.py
parentf99af8529b6969986f0c3e03f6ff9b7bb9d53ece (diff)
downloadspark-35e3d199f04fba3230625002a458d43b9578b2e8.tar.gz
spark-35e3d199f04fba3230625002a458d43b9578b2e8.tar.bz2
spark-35e3d199f04fba3230625002a458d43b9578b2e8.zip
SPARK-1438 RDD.sample() make seed param optional
copying from previous pull request https://github.com/apache/spark/pull/462 It's probably better to let the underlying language implementation take care of the default. This was easier to do with Python, as the default value for seed in random and numpy random is None. On the Scala/Java side it might mean propagating an Option or null (oh no!) down the chain until where the Random is constructed. But, it looks like the convention in some other methods was to use System.nanoTime. So, I followed that convention. This conflicts with the overloaded method in sql.SchemaRDD.sample, which also defines default params: sample(fraction, withReplacement=false, seed=math.random). Scala does not allow more than one overloaded method to have default params. I believe the author intended to override the RDD.sample method and not overload it. So, I changed it. If backward compatibility is important, 3 new methods can be introduced (without default params) like this: sample(fraction) sample(fraction, withReplacement) sample(fraction, withReplacement, seed) Added some tests for the Scala RDD takeSample method. Author: Arun Ramakrishnan <smartnut007@gmail.com> This patch had conflicts when merged, resolved by Committer: Matei Zaharia <matei@databricks.com> Closes #477 from smartnut007/master and squashes the following commits: 07bb06e [Arun Ramakrishnan] SPARK-1438 fixing more space formatting issues b9ebfe2 [Arun Ramakrishnan] SPARK-1438 removing redundant import of random in python rddsampler 8d05b1a [Arun Ramakrishnan] SPARK-1438 RDD . Replace System.nanoTime with a Random generated number. python: use a separate instance of Random instead of seeding language api global Random instance. 69619c6 [Arun Ramakrishnan] SPARK-1438 fix spacing issue 0c247db [Arun Ramakrishnan] SPARK-1438 RDD language apis to support optional seed in RDD methods sample/takeSample
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--python/pyspark/rdd.py13
1 file changed, 6 insertions, 7 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 91fc7e637e..d73ab7006e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,6 +30,7 @@ from tempfile import NamedTemporaryFile
from threading import Thread
import warnings
import heapq
+from random import Random
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -332,7 +333,7 @@ class RDD(object):
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)
- def sample(self, withReplacement, fraction, seed):
+ def sample(self, withReplacement, fraction, seed=None):
"""
Return a sampled subset of this RDD (relies on numpy and falls back
on default random generator if numpy is unavailable).
@@ -344,7 +345,7 @@ class RDD(object):
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
# this is ported from scala/spark/RDD.scala
- def takeSample(self, withReplacement, num, seed):
+ def takeSample(self, withReplacement, num, seed=None):
"""
Return a fixed-size sampled subset of this RDD (currently requires numpy).
@@ -381,13 +382,11 @@ class RDD(object):
# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
+ rand = Random(seed)
while len(samples) < total:
- if seed > sys.maxint - 2:
- seed = -1
- seed += 1
- samples = self.sample(withReplacement, fraction, seed).collect()
+ samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
- sampler = RDDSampler(withReplacement, fraction, seed+1)
+ sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
sampler.shuffle(samples)
return samples[0:total]