diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-05-05 07:53:11 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-05 07:53:11 -0700 |
commit | 5ab652cdb8bef10214edd079502a7f49017579aa (patch) | |
tree | c41047fb05c22525d383b758dac0304d53f982c1 /python | |
parent | c6d1efba29a4235130024fee9f118e6b2cb89ce1 (diff) | |
download | spark-5ab652cdb8bef10214edd079502a7f49017579aa.tar.gz spark-5ab652cdb8bef10214edd079502a7f49017579aa.tar.bz2 spark-5ab652cdb8bef10214edd079502a7f49017579aa.zip |
[SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe
Utilities for pickling and unpickling SparseMatrices using SerDe
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #5775 from MechCoder/spark-7202 and squashes the following commits:
7e689dc [MechCoder] [SPARK-7202] Add SparseMatrixPickler to SerDe
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/linalg.py | 4 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 3 |
2 files changed, 5 insertions, 2 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index a57c0b3ae0..9f3b0baf9f 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -755,7 +755,7 @@ class SparseMatrix(Matrix): return SparseMatrix, ( self.numRows, self.numCols, self.colPtrs.tostring(), self.rowIndices.tostring(), self.values.tostring(), - self.isTransposed) + int(self.isTransposed)) def __getitem__(self, indices): i, j = indices @@ -801,7 +801,7 @@ class SparseMatrix(Matrix): # TODO: More efficient implementation: def __eq__(self, other): - return np.all(self.toArray == other.toArray) + return np.all(self.toArray() == other.toArray()) class Matrices(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1b008b93bc..1d9c6ebf3b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -92,6 +92,9 @@ class VectorTests(MLlibTestCase): self._test_serialize(SparseVector(4, {1: 1, 3: 2})) self._test_serialize(SparseVector(3, {})) self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2}) |