diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/mllib/linalg.py | 33 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 18 |
2 files changed, 51 insertions, 0 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 4f8491f43e..7f21190ed8 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -510,6 +510,39 @@ class SparseVector(Vector):
                 and np.array_equal(other.indices, self.indices)
                 and np.array_equal(other.values, self.values))
 
+    def __getitem__(self, index):
+        """
+        Return the value stored at position ``index``.
+
+        Negative indices count from the end; positions without an
+        explicitly stored entry yield ``0.``.
+
+        :param index: position to read, a Python or NumPy integer
+        :return: the stored value, or ``0.`` for an implicit zero
+        :raises ValueError: if ``index`` is not an integer or lies
+                            outside ``[-size, size)``
+        """
+        inds = self.indices
+        vals = self.values
+        if not isinstance(index, (int, np.integer)):
+            raise ValueError(
+                "Indices must be of type integer, got type %s" % type(index))
+        # Plain int keeps the shift below free of fixed-width overflow.
+        index = int(index)
+        # Validate before shifting so the message shows the caller's index.
+        if index >= self.size or index < -self.size:
+            raise ValueError("Index %d out of bounds." % index)
+        if index < 0:
+            index += self.size
+
+        # searchsorted returns len(inds) when index exceeds every stored
+        # index (or inds is empty); indexing inds there would raise
+        # IndexError, so such positions are reported as implicit zeros.
+        insert_index = np.searchsorted(inds, index)
+        if insert_index < len(inds) and inds[insert_index] == index:
+            return vals[insert_index]
+        return 0.
+
     def __ne__(self, other):
         return not self.__eq__(other)
 
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1f48bc1219..140c22b5fd 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -120,6 +120,24 @@ class VectorTests(PySparkTestCase):
         dv = DenseVector(v)
         self.assertTrue(dv.array.dtype == 'float64')
 
+    def test_sparse_vector_indexing(self):
+        """Indexing returns stored values and zeros; bad indices raise."""
+        sv = SparseVector(4, {1: 1, 3: 2})
+        self.assertEqual(sv[0], 0.)
+        self.assertEqual(sv[3], 2.)
+        self.assertEqual(sv[1], 1.)
+        self.assertEqual(sv[2], 0.)
+        self.assertEqual(sv[-1], 2)
+        self.assertEqual(sv[-2], 0)
+        self.assertEqual(sv[-4], 0)
+        for ind in [4, -5, 7.8]:
+            self.assertRaises(ValueError, sv.__getitem__, ind)
+        # Positions past the last stored index are implicit zeros.
+        sparse_head = SparseVector(4, {0: 1})
+        self.assertEqual(sparse_head[0], 1.)
+        self.assertEqual(sparse_head[3], 0.)
+        self.assertEqual(sparse_head[-1], 0.)
+
 
 class ListTests(PySparkTestCase):
 