Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  191
1 file changed, 171 insertions(+), 20 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c6a6b24c5a..914118ccdd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -21,6 +21,7 @@ from collections import defaultdict
from itertools import chain, ifilter, imap, product
import operator
import os
+import sys
import shlex
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
@@ -31,6 +32,8 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
read_from_pickle_file
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
+from pyspark.statcounter import StatCounter
+from pyspark.rddsampler import RDDSampler
from py4j.java_collections import ListConverter, MapConverter
@@ -160,18 +163,64 @@ class RDD(object):
>>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
[1, 2, 3]
"""
- return self.map(lambda x: (x, "")) \
+ return self.map(lambda x: (x, None)) \
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)
- # TODO: sampling needs to be re-implemented due to Batch
- #def sample(self, withReplacement, fraction, seed):
- # jrdd = self._jrdd.sample(withReplacement, fraction, seed)
- # return RDD(jrdd, self.ctx)
+ def sample(self, withReplacement, fraction, seed):
+ """
+ Return a sampled subset of this RDD (relies on numpy and falls back
+ on default random generator if numpy is unavailable).
+
+ >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
+ [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
+ """
+ return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True)
+
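A rough, stand-alone sketch of the per-partition sampling idea that RDDSampler wraps (the function and variable names below are illustrative, not taken from the patch): each split seeds its own generator from the global seed plus the split index, then keeps each element with probability fraction. RDDSampler prefers numpy when available, as the docstring above notes; this sketch uses the standard random module.

    import random

    def sample_partition(split, iterator, fraction=0.10, seed=2):
        # Seed per split so partitions draw different streams while the
        # whole job stays reproducible for a fixed seed (no replacement here).
        rng = random.Random(seed + split)
        for item in iterator:
            if rng.random() < fraction:
                yield item

    # Local check, treating range(0, 100) as the contents of split 0:
    print(list(sample_partition(0, iter(range(0, 100)))))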
+ # this is ported from scala/spark/RDD.scala
+ def takeSample(self, withReplacement, num, seed):
+ """
+ Return a fixed-size sampled subset of this RDD (currently requires numpy).
+
+ >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
+ [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
+ """
+
+ fraction = 0.0
+ total = 0
+ multiplier = 3.0
+ initialCount = self.count()
+ maxSelected = 0
+
+ if (num < 0):
+ raise ValueError
- #def takeSample(self, withReplacement, num, seed):
- # vals = self._jrdd.takeSample(withReplacement, num, seed)
- # return [load_pickle(bytes(x)) for x in vals]
+ if initialCount > sys.maxint - 1:
+ maxSelected = sys.maxint - 1
+ else:
+ maxSelected = initialCount
+
+ if num > initialCount and not withReplacement:
+ total = maxSelected
+ fraction = multiplier * (maxSelected + 1) / initialCount
+ else:
+ fraction = multiplier * (num + 1) / initialCount
+ total = num
+
+ samples = self.sample(withReplacement, fraction, seed).collect()
+
+ # If the first sample didn't turn out large enough, keep trying to take samples;
+ # this shouldn't happen often because we use a big multiplier for their initial size.
+ # See: scala/spark/RDD.scala
+ while len(samples) < total:
+ if seed > sys.maxint - 2:
+ seed = -1
+ seed += 1
+ samples = self.sample(withReplacement, fraction, seed).collect()
+
+ sampler = RDDSampler(withReplacement, fraction, seed+1)
+ sampler.shuffle(samples)
+ return samples[0:total]
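The logic above oversamples by a multiplier of 3 so that a single distributed pass usually returns at least num elements; if the sample still comes up short, the seed is bumped and the pass repeated, and the result is finally shuffled and trimmed to exactly num. A local analogue of that flow, with made-up names and a plain list standing in for the RDD:

    import random

    def take_sample_local(data, num, seed, multiplier=3.0):
        # Oversample so one pass usually suffices, then trim to exactly num.
        fraction = multiplier * (num + 1) / float(len(data))
        samples = []
        while len(samples) < num:
            rng = random.Random(seed)
            samples = [x for x in data if rng.random() < fraction]
            seed += 1      # rare retry path: advance the seed and resample
        random.Random(seed).shuffle(samples)
        return samples[:num]

    print(take_sample_local(list(range(100)), 10, seed=1))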
def union(self, other):
"""
@@ -267,7 +316,11 @@ class RDD(object):
>>> def f(x): print x
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
- self.map(f).collect() # Force evaluation
+ def processPartition(iterator):
+ for x in iterator:
+ f(x)
+ yield None
+ self.mapPartitions(processPartition).collect() # Force evaluation
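Rewriting foreach on top of mapPartitions streams each partition's iterator through f on the worker for its side effects and collects only None placeholders, rather than shipping whatever f happens to return back to the driver. A minimal local illustration of the same pattern (hypothetical helper, no SparkContext involved):

    def process_partition_with(f):
        # Apply f to every element for its side effect, yielding None per element.
        def process(iterator):
            for x in iterator:
                f(x)
                yield None
        return process

    seen = []
    list(process_partition_with(seen.append)(iter([1, 2, 3])))   # drain the generator
    print(seen)   # [1, 2, 3]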
def collect(self):
"""
@@ -353,6 +406,63 @@ class RDD(object):
3
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
+
+ def stats(self):
+ """
+ Return a L{StatCounter} object that captures the mean, variance
+ and count of the RDD's elements in one operation.
+ """
+ def redFunc(left_counter, right_counter):
+ return left_counter.mergeStats(right_counter)
+
+ return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
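stats() builds one StatCounter per partition and then reduces them pairwise with mergeStats, so mean and variance are combined without a second pass over the data. A hedged sketch of that merge arithmetic (a cut-down counter with illustrative names, not StatCounter's actual fields), using the standard parallel-variance combination:

    class MiniStats(object):
        # Tracks count, mean, and m2 (sum of squared deviations from the mean).
        def __init__(self, values=()):
            self.n, self.mu, self.m2 = 0, 0.0, 0.0
            for v in values:
                self.n += 1
                delta = v - self.mu
                self.mu += delta / self.n
                self.m2 += delta * (v - self.mu)

        def merge_stats(self, other):
            # Combine two partial summaries without revisiting the raw values.
            if other.n:
                delta = other.mu - self.mu
                total = self.n + other.n
                self.mu = (self.n * self.mu + other.n * other.mu) / total
                self.m2 += other.m2 + delta * delta * self.n * other.n / float(total)
                self.n = total
            return self

        def variance(self):
            return self.m2 / self.n

    print(MiniStats([1, 2]).merge_stats(MiniStats([3])).variance())   # 0.666...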
+
+ def mean(self):
+ """
+ Compute the mean of this RDD's elements.
+
+ >>> sc.parallelize([1, 2, 3]).mean()
+ 2.0
+ """
+ return self.stats().mean()
+
+ def variance(self):
+ """
+ Compute the variance of this RDD's elements.
+
+ >>> sc.parallelize([1, 2, 3]).variance()
+ 0.666...
+ """
+ return self.stats().variance()
+
+ def stdev(self):
+ """
+ Compute the standard deviation of this RDD's elements.
+
+ >>> sc.parallelize([1, 2, 3]).stdev()
+ 0.816...
+ """
+ return self.stats().stdev()
+
+ def sampleStdev(self):
+ """
+ Compute the sample standard deviation of this RDD's elements (which corrects for bias in
+ estimating the standard deviation by dividing by N-1 instead of N).
+
+ >>> sc.parallelize([1, 2, 3]).sampleStdev()
+ 1.0
+ """
+ return self.stats().sampleStdev()
+
+ def sampleVariance(self):
+ """
+ Compute the sample variance of this RDD's elements (which corrects for bias in
+ estimating the variance by dividing by N-1 instead of N).
+
+ >>> sc.parallelize([1, 2, 3]).sampleVariance()
+ 1.0
+ """
+ return self.stats().sampleVariance()
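For the three-element doctests above, the population statistics divide the summed squared deviations by N = 3 while the sample statistics divide by N - 1 = 2, which is exactly where 0.666..., 0.816... and 1.0 come from. A quick check in plain Python:

    import math

    xs = [1, 2, 3]
    mean = sum(xs) / float(len(xs))           # 2.0
    ss = sum((x - mean) ** 2 for x in xs)     # summed squared deviations = 2.0
    print(ss / len(xs))                       # 0.666...  variance (divide by N)
    print(math.sqrt(ss / len(xs)))            # 0.816...  stdev
    print(ss / (len(xs) - 1))                 # 1.0       sampleVariance (divide by N-1)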
def countByValue(self):
"""
@@ -386,13 +496,16 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
"""
+ def takeUpToNum(iterator):
+ taken = 0
+ while taken < num:
+ yield next(iterator)
+ taken += 1
+ # Take only up to num elements from each partition we try
+ mapped = self.mapPartitions(takeUpToNum)
items = []
- for partition in range(self._jrdd.splits().size()):
- iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
- # Each item in the iterator is a string, Python object, batch of
- # Python objects. Regardless, it is sufficient to take `num`
- # of these objects in order to collect `num` Python objects:
- iterator = iterator.take(num)
+ for partition in range(mapped._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break
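The takeUpToNum generator caps how much of each partition is deserialized: it stops pulling from the partition iterator once num elements have been yielded, and a shorter partition simply ends early. The same pattern in plain Python (num is passed explicitly here, and a StopIteration guard is added that the in-diff version leaves implicit):

    def take_up_to(iterator, num):
        # Yield at most num items, then abandon the iterator.
        taken = 0
        while taken < num:
            try:
                yield next(iterator)
            except StopIteration:
                return
            taken += 1

    print(list(take_up_to(iter([2, 3, 4, 5, 6]), 10)))   # [2, 3, 4, 5, 6]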
@@ -689,6 +802,43 @@ class RDD(object):
"""
return python_cogroup(self, other, numPartitions)
+ def subtractByKey(self, other, numPartitions=None):
+ """
+ Return each (key, value) pair in C{self} that has no pair with matching key
+ in C{other}.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)])
+ >>> y = sc.parallelize([("a", 3), ("c", None)])
+ >>> sorted(x.subtractByKey(y).collect())
+ [('b', 4), ('b', 5)]
+ """
+ filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0
+ map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]]
+ return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
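subtractByKey leans entirely on cogroup: for each key, cogroup yields a pair of value lists (one per RDD), a key survives the filter only if it has values on the left and none on the right, and the flatMap re-expands the surviving value lists into individual pairs. A local dict-based analogue of that filter-then-flatten step (hypothetical helper, plain lists instead of RDDs):

    def subtract_by_key_local(left, right):
        # Emulate cogroup: key -> ([left values], [right values]).
        grouped = {}
        for k, v in left:
            grouped.setdefault(k, ([], []))[0].append(v)
        for k, v in right:
            grouped.setdefault(k, ([], []))[1].append(v)
        # Keep keys that appear only on the left, then flatten their values.
        return [(k, v) for k, (lv, rv) in grouped.items() if lv and not rv for v in lv]

    x = [("a", 1), ("b", 4), ("b", 5), ("a", 2)]
    y = [("a", 3), ("c", None)]
    print(sorted(subtract_by_key_local(x, y)))   # [('b', 4), ('b', 5)]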
+
+ def subtract(self, other, numPartitions=None):
+ """
+ Return each value in C{self} that is not contained in C{other}.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 3)])
+ >>> y = sc.parallelize([("a", 3), ("c", None)])
+ >>> sorted(x.subtract(y).collect())
+ [('a', 1), ('b', 4), ('b', 5)]
+ """
+ rdd = other.map(lambda x: (x, True)) # note: here 'True' is just a placeholder
+ return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) # note: here 'True' is just a placeholder
+
+ def keyBy(self, f):
+ """
+ Creates tuples of the elements in this RDD by applying C{f}.
+
+ >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
+ >>> y = sc.parallelize(zip(range(0,5), range(0,5)))
+ >>> sorted(x.cogroup(y).collect())
+ [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
+ """
+ return self.map(lambda x: (f(x), x))
+
# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
# keys in the pairs. This could be an expensive operation, since those
@@ -749,11 +899,12 @@ class PipelinedRDD(RDD):
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
- env = copy.copy(self.ctx.environment)
- env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
- env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+ env = MapConverter().convert(self.ctx.environment,
+ self.ctx._gateway._gateway_client)
+ includes = ListConverter().convert(self.ctx._python_includes,
+ self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+ pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
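The constructor change above hands the driver's environment dict and the list of shipped Python include files to the JVM-side PythonRDD via Py4J's collection converters, instead of copying the environment by hand. A stand-alone sketch of that conversion step, assuming a Py4J GatewayServer is already listening on the JVM side (the example values are made up):

    from py4j.java_gateway import JavaGateway
    from py4j.java_collections import ListConverter, MapConverter

    gateway = JavaGateway()            # connects to an already-running GatewayServer
    client = gateway._gateway_client

    # A Python dict/list become java.util.Map / java.util.List proxies that a
    # JVM constructor can consume directly.
    env = MapConverter().convert({"SPARK_LOCAL_DIRS": "/tmp"}, client)
    includes = ListConverter().convert(["deps.zip", "helpers.egg"], client)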
@@ -769,7 +920,7 @@ def _test():
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- (failure_count, test_count) = doctest.testmod(globs=globs)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
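Passing optionflags=doctest.ELLIPSIS to testmod is what lets truncated expected values such as 0.666... and 0.816... in the new docstrings match the full floating-point output. A minimal stand-alone illustration:

    import doctest

    def third():
        """With ELLIPSIS enabled, the trailing ... matches the remaining digits.

        >>> third()
        0.333...
        """
        return 1.0 / 3

    print(doctest.testmod(optionflags=doctest.ELLIPSIS))   # TestResults(failed=0, attempted=1)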