about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- python/pyspark/tests.py 24
1 files changed, 24 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c29deb9574..16fb5a9256 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -38,12 +38,19 @@ from pyspark.serializers import read_int
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
_have_scipy = False
+_have_numpy = False
try:
import scipy.sparse
_have_scipy = True
except:
# No SciPy, but that's okay, we'll skip those tests
pass
+try:
+    import numpy as np
+    _have_numpy = True
+except ImportError:
+    # NumPy is optional: when it is missing, the NumPy-dependent tests are skipped.
+    pass
SPARK_HOME = os.environ["SPARK_HOME"]
@@ -914,9 +921,26 @@ class SciPyTests(PySparkTestCase):
self.assertEqual(expected, observed)
+@unittest.skipIf(not _have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+    """PySpark tests that need NumPy; the whole class is skipped without it."""
+
+    def test_statcounter_array(self):
+        # stats() on an RDD of equal-length arrays is expected to work element-wise.
+        stats = self.sc.parallelize([np.array([v, v]) for v in (1.0, 2.0, 3.0)]).stats()
+        self.assertSequenceEqual([2.0, 2.0], stats.mean().tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats.min().tolist())
+        self.assertSequenceEqual([3.0, 3.0], stats.max().tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats.sampleStdev().tolist())
+
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
+ if not _have_numpy:
+ print "NOTE: Skipping NumPy tests as it does not seem to be installed"
+ # NOTE(review): unittest.main() exits the interpreter by default, so the reminders after it may never print -- confirm.
unittest.main()
if not _have_scipy:
print "NOTE: SciPy tests were skipped as it does not seem to be installed"
+ if not _have_numpy:
+ print "NOTE: NumPy tests were skipped as it does not seem to be installed"