diff options
author | Jeff Zhang <zjffdu@apache.org> | 2016-08-19 12:38:15 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-19 12:38:15 +0100 |
commit | 072acf5e1460d66d4b60b536d5b2ccddeee80794 (patch) | |
tree | 82627c726b931b61da6850b5e4b557d4b62e8bc1 /mllib-local/src | |
parent | 864be9359ae2f8409e6dbc38a7a18593f9cc5692 (diff) | |
download | spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.tar.gz spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.tar.bz2 spark-072acf5e1460d66d4b60b536d5b2ccddeee80794.zip |
[SPARK-16965][MLLIB][PYSPARK] Fix bound checking for SparseVector.
## What changes were proposed in this pull request?
1. In scala, add negative low bound checking and put all the low/upper bound checking in one place
2. In python, add low/upper bound checking of indices.
## How was this patch tested?
unit test added
Author: Jeff Zhang <zjffdu@apache.org>
Closes #14555 from zjffdu/SPARK-16965.
Diffstat (limited to 'mllib-local/src')
-rw-r--r-- | mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala | 34 | ||||
-rw-r--r-- | mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala | 6 |
2 files changed, 25 insertions, 15 deletions
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 0659324aad..2e4a58dc62 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -208,17 +208,7 @@ object Vectors { */ @Since("2.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0, "The size of the requested sparse vector must be greater than 0.") - val (indices, values) = elements.sortBy(_._1).unzip - var prev = -1 - indices.foreach { i => - require(prev < i, s"Found duplicate indices: $i.") - prev = i - } - require(prev < size, s"You may not write an element to index $prev because the declared " + - s"size of your vector is $size") - new SparseVector(size, indices.toArray, values.toArray) } @@ -560,11 +550,25 @@ class SparseVector @Since("2.0.0") ( @Since("2.0.0") val indices: Array[Int], @Since("2.0.0") val values: Array[Double]) extends Vector { - require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" indices match the dimension of the values. You provided ${indices.length} indices and " + - s" ${values.length} values.") - require(indices.length <= size, s"You provided ${indices.length} indices and values, " + - s"which exceeds the specified vector size ${size}.") + // validate the data + { + require(size >= 0, "The size of the requested sparse vector must be greater than 0.") + require(indices.length == values.length, "Sparse vectors require that the dimension of the" + + s" indices match the dimension of the values. You provided ${indices.length} indices and " + + s" ${values.length} values.") + require(indices.length <= size, s"You provided ${indices.length} indices and values, " + + s"which exceeds the specified vector size ${size}.") + + if (indices.nonEmpty) { + require(indices(0) >= 0, s"Found negative index: ${indices(0)}.") + } + var prev = -1 + indices.foreach { i => + require(prev < i, s"Index $i follows $prev and is not strictly increasing") + prev = i + } + require(prev < size, s"Index $prev out of bounds for vector of size $size") + } override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 614be460a4..ea22c2787f 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -72,6 +72,12 @@ class VectorsSuite extends SparkMLFunSuite { } } + test("sparse vector construction with negative indices") { + intercept[IllegalArgumentException] { + Vectors.sparse(3, Array(-1, 1), Array(3.0, 5.0)) + } + } + test("dense to array") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.toArray.eq(arr)) |