aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author: MechCoder <manojkumarsivaraj334@gmail.com> 2015-05-05 07:53:11 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-05-05 07:53:11 -0700
commit: 5ab652cdb8bef10214edd079502a7f49017579aa (patch)
tree: c41047fb05c22525d383b758dac0304d53f982c1 /python
parent: c6d1efba29a4235130024fee9f118e6b2cb89ce1 (diff)
download: spark-5ab652cdb8bef10214edd079502a7f49017579aa.tar.gz
spark-5ab652cdb8bef10214edd079502a7f49017579aa.tar.bz2
spark-5ab652cdb8bef10214edd079502a7f49017579aa.zip
[SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe
Utilities for pickling and unpickling SparseMatrices using SerDe.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #5775 from MechCoder/spark-7202 and squashes the following commits:

7e689dc [MechCoder] [SPARK-7202] Add SparseMatrixPickler to SerDe
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/mllib/linalg.py  4
-rw-r--r--  python/pyspark/mllib/tests.py   3
2 files changed, 5 insertions, 2 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index a57c0b3ae0..9f3b0baf9f 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -755,7 +755,7 @@ class SparseMatrix(Matrix):
return SparseMatrix, (
self.numRows, self.numCols, self.colPtrs.tostring(),
self.rowIndices.tostring(), self.values.tostring(),
- self.isTransposed)
+ int(self.isTransposed))
def __getitem__(self, indices):
i, j = indices
@@ -801,7 +801,7 @@ class SparseMatrix(Matrix):
# TODO: More efficient implementation:
def __eq__(self, other):
- return np.all(self.toArray == other.toArray)
+ return np.all(self.toArray() == other.toArray())
class Matrices(object):
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1b008b93bc..1d9c6ebf3b 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -92,6 +92,9 @@ class VectorTests(MLlibTestCase):
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
self._test_serialize(SparseVector(3, {}))
self._test_serialize(DenseMatrix(2, 3, range(6)))
+ sm1 = SparseMatrix(
+ 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
+ self._test_serialize(sm1)
def test_dot(self):
sv = SparseVector(4, {1: 1, 3: 2})