diff options
author | Jeremy Freeman <the.freeman.lab@gmail.com> | 2014-08-01 22:33:25 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@apache.org> | 2014-08-01 22:33:25 -0700 |
commit | 4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace (patch) | |
tree | ece0def1b321943074f43a6670040c02711604e3 /python/pyspark/statcounter.py | |
parent | fda475987f3b8b37d563033b0e45706ce433824a (diff) | |
download | spark-4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace.tar.gz spark-4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace.tar.bz2 spark-4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace.zip |
StatCounter on NumPy arrays [PYSPARK][SPARK-2012]
These changes allow StatCounter to work properly on NumPy arrays, fixing the issue reported in SPARK-2012 (https://issues.apache.org/jira/browse/SPARK-2012).
If NumPy is installed, the NumPy functions ``maximum``, ``minimum``, and ``sqrt``, which operate element-wise on arrays, are used to update and merge statistics. If NumPy is not installed, we fall back on the built-in ``max`` and ``min`` and on ``math.sqrt``, so StatCounter works on arrays when NumPy is available and still works on scalar values without it.
New unit tests have been added, along with a check for NumPy availability in the tests.
Author: Jeremy Freeman <the.freeman.lab@gmail.com>
Closes #1725 from freeman-lab/numpy-max-statcounter and squashes the following commits:
fe973b1 [Jeremy Freeman] Avoid duplicate array import in tests
7f0e397 [Jeremy Freeman] Refactored check for numpy
8e764dd [Jeremy Freeman] Explicit numpy imports
875414c [Jeremy Freeman] Fixed indents
1c8a832 [Jeremy Freeman] Unit tests for StatCounter with NumPy arrays
176a127 [Jeremy Freeman] Use numpy arrays in StatCounter
Diffstat (limited to 'python/pyspark/statcounter.py')
-rw-r--r-- | python/pyspark/statcounter.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index e287bd3da1..1e597d64e0 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -20,6 +20,13 @@ import copy import math +try: + from numpy import maximum, minimum, sqrt +except ImportError: + maximum = max + minimum = min + sqrt = math.sqrt + class StatCounter(object): @@ -39,10 +46,8 @@ class StatCounter(object): self.n += 1 self.mu += delta / self.n self.m2 += delta * (value - self.mu) - if self.maxValue < value: - self.maxValue = value - if self.minValue > value: - self.minValue = value + self.maxValue = maximum(self.maxValue, value) + self.minValue = minimum(self.minValue, value) return self @@ -70,8 +75,8 @@ class StatCounter(object): else: self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) - self.maxValue = max(self.maxValue, other.maxValue) - self.minValue = min(self.minValue, other.minValue) + self.maxValue = maximum(self.maxValue, other.maxValue) + self.minValue = minimum(self.minValue, other.minValue) self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) self.n += other.n @@ -115,14 +120,14 @@ class StatCounter(object): # Return the standard deviation of the values. def stdev(self): - return math.sqrt(self.variance()) + return sqrt(self.variance()) # # Return the sample standard deviation of the values, which corrects for bias in estimating the # variance by dividing by N-1 instead of N. # def sampleStdev(self): - return math.sqrt(self.sampleVariance()) + return sqrt(self.sampleVariance()) def __repr__(self): return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % |