aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorErik Shilts <erik.shilts@opower.com>2015-09-29 13:38:15 -0700
committerDavies Liu <davies.liu@gmail.com>2015-09-29 13:38:15 -0700
commit7d399c9daa6769ab234890c551e1b3456e0e6e85 (patch)
tree7248de1bb7f92111fd7cbfb22ff9ffb261bd2758 /python
parentab41864f91713b450695babd5c1424622cb57a54 (diff)
downloadspark-7d399c9daa6769ab234890c551e1b3456e0e6e85.tar.gz
spark-7d399c9daa6769ab234890c551e1b3456e0e6e85.tar.bz2
spark-7d399c9daa6769ab234890c551e1b3456e0e6e85.zip
[SPARK-6919] [PYSPARK] Add asDict method to StatCounter
Add method to easily convert a StatCounter instance into a Python dict https://issues.apache.org/jira/browse/SPARK-6919 Note: This is my original work and the existing Spark license applies. Author: Erik Shilts <erik.shilts@opower.com> Closes #5516 from eshilts/statcounter-asdict.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/statcounter.py22
-rw-r--r--python/pyspark/tests.py20
2 files changed, 42 insertions, 0 deletions
diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
index 0fee3b2096..03ea0b6d33 100644
--- a/python/pyspark/statcounter.py
+++ b/python/pyspark/statcounter.py
@@ -131,6 +131,28 @@ class StatCounter(object):
def sampleStdev(self):
    """Return the sample standard deviation: sqrt of the sample
    variance (n-1 denominator), via :meth:`sampleVariance`.
    """
    return sqrt(self.sampleVariance())
def asDict(self, sample=False):
    """Returns the :class:`StatCounter` members as a ``dict``.

    :param sample: if ``False`` (the default), the ``stdev`` and
        ``variance`` entries hold the *sample* statistics
        (``sampleStdev()``/``sampleVariance()``, n-1 denominator);
        if ``True`` they hold the population statistics
        (``stdev()``/``variance()``, n denominator).

        NOTE(review): this mapping is inverted relative to what the
        parameter name suggests, but it is the behavior pinned by the
        example below and by the accompanying tests, so it is kept
        for backward compatibility.
    :return: dict with keys ``count``, ``mean``, ``sum``, ``min``,
        ``max``, ``stdev`` and ``variance``.

    Example (not a doctest -- exact dict ordering and the integer
    repr are version dependent)::

        sc.parallelize([1., 2., 3., 4.]).stats().asDict()
        # {'count': 4, 'max': 4.0, 'mean': 2.5, 'min': 1.0,
        #  'stdev': 1.2909944487358056, 'sum': 10.0,
        #  'variance': 1.6666666666666667}
    """
    return {
        'count': self.count(),
        'mean': self.mean(),
        'sum': self.sum(),
        'min': self.min(),
        'max': self.max(),
        # Deliberately inverted -- see the docstring note above.
        'stdev': self.stdev() if sample else self.sampleStdev(),
        'variance': self.variance() if sample else self.sampleVariance(),
    }
+
def __repr__(self):
    """Concise one-line summary of the counter's statistics."""
    return (
        f"(count: {self.count()}, mean: {self.mean()}, "
        f"stdev: {self.stdev()}, max: {self.max()}, min: {self.min()})"
    )
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f11aaf001c..63cc87e0c4 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1976,6 +1976,26 @@ class NumPyTests(PySparkTestCase):
self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
+ stats_dict = s.asDict()
+ self.assertEqual(3, stats_dict['count'])
+ self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
+ self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
+ self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())
+
+ stats_sample_dict = s.asDict(sample=True)
+ self.assertEqual(3, stats_sample_dict['count'])
+ self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
+ self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
+ self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
+ self.assertSequenceEqual(
+ [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
+ self.assertSequenceEqual(
+ [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
+
if __name__ == "__main__":
if not _have_scipy: