diff options
author | Jeff Zhang <zjffdu@apache.org> | 2016-08-19 12:38:15 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-19 12:38:15 +0100 |
commit | 072acf5e1460d66d4b60b536d5b2ccddeee80794 (patch) | |
tree | 82627c726b931b61da6850b5e4b557d4b62e8bc1 /python/pyspark/ml/linalg | |
parent | 864be9359ae2f8409e6dbc38a7a18593f9cc5692 (diff) | |
download | spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.tar.gz spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.tar.bz2 spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.zip |
[SPARK-16965][MLLIB][PYSPARK] Fix bound checking for SparseVector.
## What changes were proposed in this pull request?
1. In Scala, add negative lower-bound checking and put all the lower/upper bound checking in one place.
2. In Python, add lower/upper bound checking of indices.
## How was this patch tested?
Unit test added.
Author: Jeff Zhang <zjffdu@apache.org>
Closes #14555 from zjffdu/SPARK-16965.
Diffstat (limited to 'python/pyspark/ml/linalg')
-rw-r--r-- | python/pyspark/ml/linalg/__init__.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index f42c589b92..05c0ac862f 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -478,6 +478,14 @@ class SparseVector(Vector): SparseVector(4, {1: 1.0, 3: 5.5}) >>> SparseVector(4, [1, 3], [1.0, 5.5]) SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, {1:1.0, 6:2.0}) + Traceback (most recent call last): + ... + AssertionError: Index 6 is out of the the size of vector with size=4 + >>> SparseVector(4, {-1:1.0}) + Traceback (most recent call last): + ... + AssertionError: Contains negative index -1 """ self.size = int(size) """ Size of the vector. """ @@ -511,6 +519,13 @@ class SparseVector(Vector): "Indices %s and %s are not strictly increasing" % (self.indices[i], self.indices[i + 1])) + if self.indices.size > 0: + assert np.max(self.indices) < self.size, \ + "Index %d is out of the the size of vector with size=%d" \ + % (np.max(self.indices), self.size) + assert np.min(self.indices) >= 0, \ + "Contains negative index %d" % (np.min(self.indices)) + def numNonzeros(self): """ Number of nonzero elements. This scans all active values and count non zeros. |