diff options
-rw-r--r-- | python/pyspark/mllib/linalg.py | 2 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 10 |
2 files changed, 11 insertions, 1 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f7aa2b0cb0..4f8491f43e 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -178,7 +178,7 @@ class DenseVector(Vector): elif not isinstance(ar, np.ndarray): ar = np.array(ar, dtype=np.float64) if ar.dtype != np.float64: - ar.astype(np.float64) + ar = ar.astype(np.float64) self.array = ar def __reduce__(self): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5034f229e8..1f48bc1219 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -110,6 +110,16 @@ class VectorTests(PySparkTestCase): self.assertEquals(0.0, _squared_distance(dv, dv)) self.assertEquals(0.0, _squared_distance(lst, lst)) + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + class ListTests(PySparkTestCase): |