aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Zhang <zjffdu@apache.org>2016-08-19 12:38:15 +0100
committerSean Owen <sowen@cloudera.com>2016-08-19 12:38:15 +0100
commit072acf5e1460d66d4b60b536d5b2ccddeee80794 (patch)
tree82627c726b931b61da6850b5e4b557d4b62e8bc1
parent864be9359ae2f8409e6dbc38a7a18593f9cc5692 (diff)
downloadspark-072acf5e1460d66d4b60b536d5b2ccddeee80794.tar.gz
spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.tar.bz2
spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.zip
[SPARK-16965][MLLIB][PYSPARK] Fix bound checking for SparseVector.
## What changes were proposed in this pull request? 1. In scala, add negative low bound checking and put all the low/upper bound checking in one place 2. In python, add low/upper bound checking of indices. ## How was this patch tested? unit test added Author: Jeff Zhang <zjffdu@apache.org> Closes #14555 from zjffdu/SPARK-16965.
-rw-r--r--mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala34
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala6
-rw-r--r--python/pyspark/ml/linalg/__init__.py15
3 files changed, 40 insertions, 15 deletions
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 0659324aad..2e4a58dc62 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -208,17 +208,7 @@ object Vectors {
*/
@Since("2.0.0")
def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
- require(size > 0, "The size of the requested sparse vector must be greater than 0.")
-
val (indices, values) = elements.sortBy(_._1).unzip
- var prev = -1
- indices.foreach { i =>
- require(prev < i, s"Found duplicate indices: $i.")
- prev = i
- }
- require(prev < size, s"You may not write an element to index $prev because the declared " +
- s"size of your vector is $size")
-
new SparseVector(size, indices.toArray, values.toArray)
}
@@ -560,11 +550,25 @@ class SparseVector @Since("2.0.0") (
@Since("2.0.0") val indices: Array[Int],
@Since("2.0.0") val values: Array[Double]) extends Vector {
- require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
- s" indices match the dimension of the values. You provided ${indices.length} indices and " +
- s" ${values.length} values.")
- require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
- s"which exceeds the specified vector size ${size}.")
+ // validate the data
+ {
+ require(size >= 0, "The size of the requested sparse vector must be greater than 0.")
+ require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
+ s" indices match the dimension of the values. You provided ${indices.length} indices and " +
+ s" ${values.length} values.")
+ require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
+ s"which exceeds the specified vector size ${size}.")
+
+ if (indices.nonEmpty) {
+ require(indices(0) >= 0, s"Found negative index: ${indices(0)}.")
+ }
+ var prev = -1
+ indices.foreach { i =>
+ require(prev < i, s"Index $i follows $prev and is not strictly increasing")
+ prev = i
+ }
+ require(prev < size, s"Index $prev out of bounds for vector of size $size")
+ }
override def toString: String =
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
index 614be460a4..ea22c2787f 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
@@ -72,6 +72,12 @@ class VectorsSuite extends SparkMLFunSuite {
}
}
+ test("sparse vector construction with negative indices") {
+ intercept[IllegalArgumentException] {
+ Vectors.sparse(3, Array(-1, 1), Array(3.0, 5.0))
+ }
+ }
+
test("dense to array") {
val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
assert(vec.toArray.eq(arr))
diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index f42c589b92..05c0ac862f 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -478,6 +478,14 @@ class SparseVector(Vector):
SparseVector(4, {1: 1.0, 3: 5.5})
>>> SparseVector(4, [1, 3], [1.0, 5.5])
SparseVector(4, {1: 1.0, 3: 5.5})
+ >>> SparseVector(4, {1:1.0, 6:2.0})
+ Traceback (most recent call last):
+ ...
+ AssertionError: Index 6 is out of the the size of vector with size=4
+ >>> SparseVector(4, {-1:1.0})
+ Traceback (most recent call last):
+ ...
+ AssertionError: Contains negative index -1
"""
self.size = int(size)
""" Size of the vector. """
@@ -511,6 +519,13 @@ class SparseVector(Vector):
"Indices %s and %s are not strictly increasing"
% (self.indices[i], self.indices[i + 1]))
+ if self.indices.size > 0:
+ assert np.max(self.indices) < self.size, \
+ "Index %d is out of the the size of vector with size=%d" \
+ % (np.max(self.indices), self.size)
+ assert np.min(self.indices) >= 0, \
+ "Contains negative index %d" % (np.min(self.indices))
+
def numNonzeros(self):
"""
Number of nonzero elements. This scans all active values and count non zeros.