aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author: zero323 <matthew.szymkiewicz@gmail.com> 2015-10-16 15:53:26 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-10-16 15:53:26 -0700
commit: 8ac71d62d976bbfd0159cac6816dd8fa580ae1cb (patch)
tree: 18740525525d04abee55c3173956dd5970365a23 /python
parent: 10046ea76cf8f0d08fe7ef548e4dbec69d9c73b8 (diff)
downloadspark-8ac71d62d976bbfd0159cac6816dd8fa580ae1cb.tar.gz
spark-8ac71d62d976bbfd0159cac6816dd8fa580ae1cb.tar.bz2
spark-8ac71d62d976bbfd0159cac6816dd8fa580ae1cb.zip
[SPARK-11084] [ML] [PYTHON] Check if index can contain non-zero value before binary search
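Currently, `SparseVector.__getitem__` executes `np.searchsorted` first and only afterwards checks whether the result falls within the expected range. It is possible to check whether the index can hold a non-zero value before executing `np.searchsorted`: if the vector stores no indices, or the requested index is greater than the last stored index, the value is 0 and the binary search can be skipped. Author: zero323 <matthew.szymkiewicz@gmail.com>. Closes #9098 from zero323/sparse_vector_getitem_improved.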
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/linalg/__init__.py | 4
-rw-r--r-- python/pyspark/mllib/tests.py | 10
2 files changed, 12 insertions, 2 deletions
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 5276eb41cf..ae9ce58450 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -770,10 +770,10 @@ class SparseVector(Vector):
if index < 0:
index += self.size
- insert_index = np.searchsorted(inds, index)
- if insert_index >= inds.size:
+ if (inds.size == 0) or (index > inds.item(-1)):
return 0.
+ insert_index = np.searchsorted(inds, index)
row_ind = inds[insert_index]
if row_ind == index:
return vals[insert_index]
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 2a6a5cd3fe..2ad69a0ab1 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -252,6 +252,16 @@ class VectorTests(MLlibTestCase):
for ind in [7.8, '1']:
self.assertRaises(TypeError, sv.__getitem__, ind)
+ zeros = SparseVector(4, {})
+ self.assertEqual(zeros[0], 0.0)
+ self.assertEqual(zeros[3], 0.0)
+ for ind in [4, -5]:
+ self.assertRaises(ValueError, zeros.__getitem__, ind)
+
+ empty = SparseVector(0, {})
+ for ind in [-1, 0, 1]:
+ self.assertRaises(ValueError, empty.__getitem__, ind)
+
def test_matrix_indexing(self):
mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
expected = [[0, 6], [1, 8], [4, 10]]