 python/pyspark/statcounter.py | 21 +++++++++++++--------
 python/pyspark/tests.py       | 24 ++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
index e287bd3da1..1e597d64e0 100644
--- a/python/pyspark/statcounter.py
+++ b/python/pyspark/statcounter.py
@@ -20,6 +20,13 @@
 import copy
 import math
 
+try:
+    from numpy import maximum, minimum, sqrt
+except ImportError:
+    maximum = max
+    minimum = min
+    sqrt = math.sqrt
+
 
 class StatCounter(object):
@@ -39,10 +46,8 @@ class StatCounter(object):
         delta = value - self.mu
         self.n += 1
         self.mu += delta / self.n
         self.m2 += delta * (value - self.mu)
-        if self.maxValue < value:
-            self.maxValue = value
-        if self.minValue > value:
-            self.minValue = value
+        self.maxValue = maximum(self.maxValue, value)
+        self.minValue = minimum(self.minValue, value)
 
         return self
@@ -70,8 +75,8 @@ class StatCounter(object):
                 else:
                     self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
 
-                self.maxValue = max(self.maxValue, other.maxValue)
-                self.minValue = min(self.minValue, other.minValue)
+                self.maxValue = maximum(self.maxValue, other.maxValue)
+                self.minValue = minimum(self.minValue, other.minValue)
                 self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
                 self.n += other.n
@@ -115,14 +120,14 @@ class StatCounter(object):
     # Return the standard deviation of the values.
     #
     def stdev(self):
-        return math.sqrt(self.variance())
+        return sqrt(self.variance())
 
     #
    # Return the sample standard deviation of the values, which corrects for bias in estimating the
     # variance by dividing by N-1 instead of N.
     #
     def sampleStdev(self):
-        return math.sqrt(self.sampleVariance())
+        return sqrt(self.sampleVariance())
 
     def __repr__(self):
         return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
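The core of the statcounter.py change is swapping the builtin max()/min() and math.sqrt for numpy's maximum/minimum/sqrt, which compare and reduce elementwise on arrays while still matching the builtins on plain scalars; the try/except fallback keeps scalar-only workloads running when NumPy is absent. A standalone sketch of why the builtins are not enough (illustrative, not part of the patch):

    import numpy as np

    a = np.array([1.0, 5.0])
    b = np.array([3.0, 2.0])

    # max(a, b) raises ValueError ("The truth value of an array with more
    # than one element is ambiguous"); the ufunc compares elementwise:
    print(np.maximum(a, b))  # [3. 5.]
    print(np.maximum(2, 7))  # 7 -- scalars still behave like builtin max()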
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c29deb9574..16fb5a9256 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -38,12 +38,19 @@ from pyspark.serializers import read_int
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
 
 _have_scipy = False
+_have_numpy = False
 try:
     import scipy.sparse
     _have_scipy = True
 except:
     # No SciPy, but that's okay, we'll skip those tests
     pass
+try:
+    import numpy as np
+    _have_numpy = True
+except ImportError:
+    # No NumPy, but that's okay, we'll skip those tests
+    pass
 
 
 SPARK_HOME = os.environ["SPARK_HOME"]
@@ -914,9 +921,26 @@ class SciPyTests(PySparkTestCase):
         self.assertEqual(expected, observed)
 
 
+@unittest.skipIf(not _have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+    """General PySpark tests that depend on NumPy."""
+
+    def test_statcounter_array(self):
+        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
+        s = x.stats()
+        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
+        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print "NOTE: Skipping SciPy tests as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: Skipping NumPy tests as it does not seem to be installed"
     unittest.main()
     if not _have_scipy:
         print "NOTE: SciPy tests were skipped as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: NumPy tests were skipped as it does not seem to be installed"