aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f11aaf001c..63cc87e0c4 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1976,6 +1976,26 @@ class NumPyTests(PySparkTestCase):
self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
+ stats_dict = s.asDict()
+ self.assertEqual(3, stats_dict['count'])
+ self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
+ self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
+ self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())
+
+ stats_sample_dict = s.asDict(sample=True)
+ self.assertEqual(3, stats_dict['count'])
+ self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
+ self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
+ self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
+ self.assertSequenceEqual(
+ [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
+ self.assertSequenceEqual(
+ [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
+
if __name__ == "__main__":
if not _have_scipy: