aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzero323 <matthew.szymkiewicz@gmail.com>2015-10-08 18:34:15 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-10-08 18:34:15 -0700
commit8e67882b905683a1f151679214ef0b575e77c7e1 (patch)
tree60b28caba87f0d2562418603d4f41c7b08532023
parent3390b400d04e40f767d8a51f1078fcccb4e64abd (diff)
downloadspark-8e67882b905683a1f151679214ef0b575e77c7e1.tar.gz
spark-8e67882b905683a1f151679214ef0b575e77c7e1.tar.bz2
spark-8e67882b905683a1f151679214ef0b575e77c7e1.zip
[SPARK-10973] [ML] [PYTHON] __gettitem__ method throws IndexError exception when we…
__gettitem__ method throws IndexError exception when we try to access index after the last non-zero entry from pyspark.mllib.linalg import Vectors sv = Vectors.sparse(5, {1: 3}) sv[0] ## 0.0 sv[1] ## 3.0 sv[2] ## Traceback (most recent call last): ## File "<stdin>", line 1, in <module> ## File "/python/pyspark/mllib/linalg/__init__.py", line 734, in __getitem__ ## row_ind = inds[insert_index] ## IndexError: index out of bounds Author: zero323 <matthew.szymkiewicz@gmail.com> Closes #9009 from zero323/sparse_vector_index_error.
-rw-r--r--python/pyspark/mllib/linalg/__init__.py3
-rw-r--r--python/pyspark/mllib/tests.py12
2 files changed, 10 insertions, 5 deletions
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index ea42127f16..d903b9030d 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -770,6 +770,9 @@ class SparseVector(Vector):
raise ValueError("Index %d out of bounds." % index)
insert_index = np.searchsorted(inds, index)
+ if insert_index >= inds.size:
+ return 0.
+
row_ind = inds[insert_index]
if row_ind == index:
return vals[insert_index]
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 96cf13495a..2a6a5cd3fe 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -237,15 +237,17 @@ class VectorTests(MLlibTestCase):
self.assertTrue(dv.array.dtype == 'float64')
def test_sparse_vector_indexing(self):
- sv = SparseVector(4, {1: 1, 3: 2})
+ sv = SparseVector(5, {1: 1, 3: 2})
self.assertEqual(sv[0], 0.)
self.assertEqual(sv[3], 2.)
self.assertEqual(sv[1], 1.)
self.assertEqual(sv[2], 0.)
- self.assertEqual(sv[-1], 2)
- self.assertEqual(sv[-2], 0)
- self.assertEqual(sv[-4], 0)
- for ind in [4, -5]:
+ self.assertEqual(sv[4], 0.)
+ self.assertEqual(sv[-1], 0.)
+ self.assertEqual(sv[-2], 2.)
+ self.assertEqual(sv[-3], 0.)
+ self.assertEqual(sv[-5], 0.)
+ for ind in [5, -6]:
self.assertRaises(ValueError, sv.__getitem__, ind)
for ind in [7.8, '1']:
self.assertRaises(TypeError, sv.__getitem__, ind)