diff options
author | Uri Laserson <laserson@cloudera.com> | 2014-05-31 14:59:09 -0700 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-05-31 14:59:09 -0700 |
commit | 5e98967b612ccf026cb1cc5ff3ac8bf72d7e836e (patch) | |
tree | d0b70fa8713defa5227e4c882ae13b1b7357ce67 | |
parent | d8c005d5371f81a2a06c5d27c7021e1ae43d7193 (diff) | |
download | spark-5e98967b612ccf026cb1cc5ff3ac8bf72d7e836e.tar.gz spark-5e98967b612ccf026cb1cc5ff3ac8bf72d7e836e.tar.bz2 spark-5e98967b612ccf026cb1cc5ff3ac8bf72d7e836e.zip |
SPARK-1917: fix PySpark import of scipy.special functions
https://issues.apache.org/jira/browse/SPARK-1917
Author: Uri Laserson <laserson@cloudera.com>
Closes #866 from laserson/SPARK-1917 and squashes the following commits:
d947e8c [Uri Laserson] Added test for scipy.special importing
1798bbd [Uri Laserson] SPARK-1917: fix PySpark import of scipy.special
-rw-r--r-- | python/pyspark/cloudpickle.py | 2 | ||||
-rw-r--r-- | python/pyspark/tests.py | 24 |
2 files changed, 25 insertions, 1 deletion
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 6a7c23a069..eb5dbb8de2 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -933,7 +933,7 @@ def _change_cell_value(cell, newval):
 Note: These can never be renamed due to client compatibility issues"""

 def _getobject(modname, attribute):
-    mod = __import__(modname)
+    mod = __import__(modname, fromlist=[attribute])
     return mod.__dict__[attribute]

 def _generateImage(size, mode, str_rep):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 64f2eeb12b..ed90915fcd 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -35,6 +35,14 @@ from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int

+_have_scipy = False
+try:
+    import scipy.sparse
+    _have_scipy = True
+except:
+    # No SciPy, but that's okay, we'll skip those tests
+    pass
+
 SPARK_HOME = os.environ["SPARK_HOME"]

@@ -359,5 +367,21 @@ class TestSparkSubmit(unittest.TestCase):
         self.assertIn("[2, 4, 6]", out)


+@unittest.skipIf(not _have_scipy, "SciPy not installed")
+class SciPyTests(PySparkTestCase):
+    """General PySpark tests that depend on scipy """
+
+    def test_serialize(self):
+        from scipy.special import gammaln
+        x = range(1, 5)
+        expected = map(gammaln, x)
+        observed = self.sc.parallelize(x).map(gammaln).collect()
+        self.assertEqual(expected, observed)
+
+
 if __name__ == "__main__":
+    if not _have_scipy:
+        print "NOTE: Skipping SciPy tests as it does not seem to be installed"
     unittest.main()
+    if not _have_scipy:
+        print "NOTE: SciPy tests were skipped as it does not seem to be installed"