author      Andre Schumacher <schumach@icsi.berkeley.edu>      2013-08-20 13:22:06 -0700
committer   Andre Schumacher <schumach@icsi.berkeley.edu>      2013-08-21 17:05:58 -0700
commit      76077bf9f4b726699ba9e59cdfa9c4361df4ea92 (patch)
tree        8e8130d9d825ae9d0216076cc62b0382a5daf677 /python/pyspark/rdd.py
parent      53b1c30607a9b19e795fd5b6107dfefb83820282 (diff)
Implementing SPARK-838: Add DoubleRDDFunctions methods to PySpark
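
This adds stats(), mean(), variance(), stdev(), sampleStdev() and sampleVariance() to the Python RDD class, mirroring Scala's DoubleRDDFunctions. A quick usage sketch (assuming a running SparkContext bound to sc, as in the doctests below; the count() and mean() accessors on the returned StatCounter follow the pyspark.statcounter port of Scala's StatCounter):

>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.mean()
2.0
>>> rdd.sampleVariance()
1.0
>>> s = rdd.stats()  # count, mean and variance computed in one pass
>>> (s.count(), s.mean())
(3, 2.0)
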
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--    python/pyspark/rdd.py    60
1 file changed, 59 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 99f5967a8e..1e9b3bb5c0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,6 +31,7 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
     read_from_pickle_file
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
+from pyspark.statcounter import StatCounter
 from py4j.java_collections import ListConverter, MapConverter
@@ -357,6 +358,63 @@ class RDD(object):
         3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
+
+    def stats(self):
+        """
+        Return a L{StatCounter} object that captures the mean, variance
+        and count of the RDD's elements in one operation.
+        """
+        def redFunc(left_counter, right_counter):
+            return left_counter.mergeStats(right_counter)
+
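+        # One pass: build a StatCounter from each partition's elements,
+        # then merge the per-partition counters pairwise with mergeStats.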
+        return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
+
+    def mean(self):
+        """
+        Compute the mean of this RDD's elements.
+
+        >>> sc.parallelize([1, 2, 3]).mean()
+        2.0
+        """
+        return self.stats().mean()
+
+    def variance(self):
+        """
+        Compute the variance of this RDD's elements.
+
+        >>> sc.parallelize([1, 2, 3]).variance()
+        0.666...
+        """
+        return self.stats().variance()
+
+    def stdev(self):
+        """
+        Compute the standard deviation of this RDD's elements.
+
+        >>> sc.parallelize([1, 2, 3]).stdev()
+        0.816...
+        """
+        return self.stats().stdev()
+
+    def sampleStdev(self):
+        """
+        Compute the sample standard deviation of this RDD's elements (which
+        corrects for bias in estimating the standard deviation by dividing
+        by N-1 instead of N).
+
+        >>> sc.parallelize([1, 2, 3]).sampleStdev()
+        1.0
+        """
+        return self.stats().sampleStdev()
+
+    def sampleVariance(self):
+        """
+        Compute the sample variance of this RDD's elements (which corrects
+        for bias in estimating the variance by dividing by N-1 instead of N).
+
+        >>> sc.parallelize([1, 2, 3]).sampleVariance()
+        1.0
+        """
+        return self.stats().sampleVariance()
 
     def countByValue(self):
"""
@@ -777,7 +835,7 @@ def _test():
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
-    (failure_count, test_count) = doctest.testmod(globs=globs)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
         exit(-1)
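
A note on the _test() change: the doctests for variance() and stdev() above truncate their expected floating-point output ("0.666...", "0.816..."), which only passes when doctest's ELLIPSIS flag is enabled; hence the new optionflags argument. A minimal standalone sketch of the behavior (two_thirds is a hypothetical function, not part of this patch):

import doctest

def two_thirds():
    """
    With ELLIPSIS enabled, the expected output "0.666..." matches the
    full repr "0.6666666666666666"; without the flag this test fails.

    >>> two_thirds()
    0.666...
    """
    return 2.0 / 3

if __name__ == "__main__":
    doctest.testmod(optionflags=doctest.ELLIPSIS)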