aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-01-14 11:03:11 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-14 11:03:11 -0800
commit5840f5464bad8431810d459c97d6e4635eea175c (patch)
treed7a60d640025f77a1c1ae06703ebce8359bd754c /python/pyspark/mllib
parent38bdc992a1a0485ac630af500da54f0a77e133bf (diff)
downloadspark-5840f5464bad8431810d459c97d6e4635eea175c.tar.gz
spark-5840f5464bad8431810d459c97d6e4635eea175c.tar.bz2
spark-5840f5464bad8431810d459c97d6e4635eea175c.zip
[SPARK-2909] [MLlib] [PySpark] SparseVector in pyspark now supports indexing
Slightly different than the scala code which converts the sparsevector into a densevector and then checks the index. I also hope I've added tests in the right place. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #4025 from MechCoder/spark-2909 and squashes the following commits: 07d0f26 [MechCoder] STY: Rename item to index f02148b [MechCoder] [SPARK-2909] [Mlib] SparseVector in pyspark now supports indexing
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r--python/pyspark/mllib/linalg.py17
-rw-r--r--python/pyspark/mllib/tests.py12
2 files changed, 29 insertions, 0 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 4f8491f43e..7f21190ed8 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -510,6 +510,23 @@ class SparseVector(Vector):
and np.array_equal(other.indices, self.indices)
and np.array_equal(other.values, self.values))
+ def __getitem__(self, index):
+ inds = self.indices
+ vals = self.values
+ if not isinstance(index, int):
+ raise ValueError(
+ "Indices must be of type integer, got type %s" % type(index))
+ if index < 0:
+ index += self.size
+ if index >= self.size or index < 0:
+ raise ValueError("Index %d out of bounds." % index)
+
+ insert_index = np.searchsorted(inds, index)
+ row_ind = inds[insert_index]
+ if row_ind == index:
+ return vals[insert_index]
+ return 0.
+
def __ne__(self, other):
return not self.__eq__(other)
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1f48bc1219..140c22b5fd 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -120,6 +120,18 @@ class VectorTests(PySparkTestCase):
dv = DenseVector(v)
self.assertTrue(dv.array.dtype == 'float64')
+ def test_sparse_vector_indexing(self):
+ sv = SparseVector(4, {1: 1, 3: 2})
+ self.assertEquals(sv[0], 0.)
+ self.assertEquals(sv[3], 2.)
+ self.assertEquals(sv[1], 1.)
+ self.assertEquals(sv[2], 0.)
+ self.assertEquals(sv[-1], 2)
+ self.assertEquals(sv[-2], 0)
+ self.assertEquals(sv[-4], 0)
+ for ind in [4, -5, 7.8]:
+ self.assertRaises(ValueError, sv.__getitem__, ind)
+
class ListTests(PySparkTestCase):