diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/linalg.py | 4 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 3 |
2 files changed, 5 insertions, 2 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index a57c0b3ae0..9f3b0baf9f 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -755,7 +755,7 @@ class SparseMatrix(Matrix): return SparseMatrix, ( self.numRows, self.numCols, self.colPtrs.tostring(), self.rowIndices.tostring(), self.values.tostring(), - self.isTransposed) + int(self.isTransposed)) def __getitem__(self, indices): i, j = indices @@ -801,7 +801,7 @@ class SparseMatrix(Matrix): # TODO: More efficient implementation: def __eq__(self, other): - return np.all(self.toArray == other.toArray) + return np.all(self.toArray() == other.toArray()) class Matrices(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1b008b93bc..1d9c6ebf3b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -92,6 +92,9 @@ class VectorTests(MLlibTestCase): self._test_serialize(SparseVector(4, {1: 1, 3: 2})) self._test_serialize(SparseVector(3, {})) self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2}) |