aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Freeman <the.freeman.lab@gmail.com>2014-08-01 22:33:25 -0700
committerJosh Rosen <joshrosen@apache.org>2014-08-01 22:33:25 -0700
commit4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace (patch)
treeece0def1b321943074f43a6670040c02711604e3
parentfda475987f3b8b37d563033b0e45706ce433824a (diff)
downloadspark-4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace.tar.gz
spark-4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace.tar.bz2
spark-4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace.zip
StatCounter on NumPy arrays [PYSPARK][SPARK-2012]
These changes allow StatCounters to work properly on NumPy arrays, to fix the issue reported here (https://issues.apache.org/jira/browse/SPARK-2012). If NumPy is installed, the NumPy functions ``maximum``, ``minimum``, and ``sqrt``, which work on arrays, are used to merge statistics. If not, we fall back on scalar operators, so it will work on arrays with NumPy, but will also work without NumPy. New unit tests added, along with a check for NumPy in the tests. Author: Jeremy Freeman <the.freeman.lab@gmail.com> Closes #1725 from freeman-lab/numpy-max-statcounter and squashes the following commits: fe973b1 [Jeremy Freeman] Avoid duplicate array import in tests 7f0e397 [Jeremy Freeman] Refactored check for numpy 8e764dd [Jeremy Freeman] Explicit numpy imports 875414c [Jeremy Freeman] Fixed indents 1c8a832 [Jeremy Freeman] Unit tests for StatCounter with NumPy arrays 176a127 [Jeremy Freeman] Use numpy arrays in StatCounter
-rw-r--r--python/pyspark/statcounter.py21
-rw-r--r--python/pyspark/tests.py24
2 files changed, 37 insertions, 8 deletions
diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
index e287bd3da1..1e597d64e0 100644
--- a/python/pyspark/statcounter.py
+++ b/python/pyspark/statcounter.py
@@ -20,6 +20,13 @@
import copy
import math
+try:
+ from numpy import maximum, minimum, sqrt
+except ImportError:
+ maximum = max
+ minimum = min
+ sqrt = math.sqrt
+
class StatCounter(object):
@@ -39,10 +46,8 @@ class StatCounter(object):
self.n += 1
self.mu += delta / self.n
self.m2 += delta * (value - self.mu)
- if self.maxValue < value:
- self.maxValue = value
- if self.minValue > value:
- self.minValue = value
+ self.maxValue = maximum(self.maxValue, value)
+ self.minValue = minimum(self.minValue, value)
return self
@@ -70,8 +75,8 @@ class StatCounter(object):
else:
self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
- self.maxValue = max(self.maxValue, other.maxValue)
- self.minValue = min(self.minValue, other.minValue)
+ self.maxValue = maximum(self.maxValue, other.maxValue)
+ self.minValue = minimum(self.minValue, other.minValue)
self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
self.n += other.n
@@ -115,14 +120,14 @@ class StatCounter(object):
# Return the standard deviation of the values.
def stdev(self):
- return math.sqrt(self.variance())
+ return sqrt(self.variance())
#
# Return the sample standard deviation of the values, which corrects for bias in estimating the
# variance by dividing by N-1 instead of N.
#
def sampleStdev(self):
- return math.sqrt(self.sampleVariance())
+ return sqrt(self.sampleVariance())
def __repr__(self):
return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c29deb9574..16fb5a9256 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -38,12 +38,19 @@ from pyspark.serializers import read_int
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
_have_scipy = False
+_have_numpy = False
try:
import scipy.sparse
_have_scipy = True
except:
# No SciPy, but that's okay, we'll skip those tests
pass
+try:
+ import numpy as np
+ _have_numpy = True
+except:
+ # No NumPy, but that's okay, we'll skip those tests
+ pass
SPARK_HOME = os.environ["SPARK_HOME"]
@@ -914,9 +921,26 @@ class SciPyTests(PySparkTestCase):
self.assertEqual(expected, observed)
+@unittest.skipIf(not _have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+ """General PySpark tests that depend on numpy """
+
+ def test_statcounter_array(self):
+ x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])])
+ s = x.stats()
+ self.assertSequenceEqual([2.0,2.0], s.mean().tolist())
+ self.assertSequenceEqual([1.0,1.0], s.min().tolist())
+ self.assertSequenceEqual([3.0,3.0], s.max().tolist())
+ self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist())
+
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
+ if not _have_numpy:
+ print "NOTE: Skipping NumPy tests as it does not seem to be installed"
unittest.main()
if not _have_scipy:
print "NOTE: SciPy tests were skipped as it does not seem to be installed"
+ if not _have_numpy:
+ print "NOTE: NumPy tests were skipped as it does not seem to be installed"