aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/linalg.py')
-rw-r--r--python/pyspark/mllib/linalg.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 4f8491f43e..7f21190ed8 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -510,6 +510,23 @@ class SparseVector(Vector):
and np.array_equal(other.indices, self.indices)
and np.array_equal(other.values, self.values))
+ def __getitem__(self, index):
+ inds = self.indices
+ vals = self.values
+ if not isinstance(index, int):
+ raise ValueError(
+ "Indices must be of type integer, got type %s" % type(index))
+ if index < 0:
+ index += self.size
+ if index >= self.size or index < 0:
+ raise ValueError("Index %d out of bounds." % index)
+
+ insert_index = np.searchsorted(inds, index)
+ row_ind = inds[insert_index]
+ if row_ind == index:
+ return vals[insert_index]
+ return 0.
+
def __ne__(self, other):
return not self.__eq__(other)