aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala57
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala39
-rw-r--r--project/MimaExcludes.scala4
3 files changed, 95 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index e048b01d92..9067b3ba9a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
toDense
}
}
+
+ /**
+ * Find the index of a maximal element. Returns the first maximal element in case of a tie.
+ * Returns -1 if vector has length 0.
+ */
+ def argmax: Int
}
/**
@@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
new SparseVector(size, ii, vv)
}
- /**
- * Find the index of a maximal element. Returns the first maximal element in case of a tie.
- * Returns -1 if vector has length 0.
- */
- private[spark] def argmax: Int = {
+ override def argmax: Int = {
if (size == 0) {
-1
} else {
@@ -717,6 +719,51 @@ class SparseVector(
new SparseVector(size, ii, vv)
}
}
+
+ override def argmax: Int = {
+ if (size == 0) {
+ -1
+ } else {
+ // Find the max active entry.
+ var maxIdx = indices(0)
+ var maxValue = values(0)
+ var maxJ = 0
+ var j = 1
+ val na = numActives
+ while (j < na) {
+ val v = values(j)
+ if (v > maxValue) {
+ maxValue = v
+ maxIdx = indices(j)
+ maxJ = j
+ }
+ j += 1
+ }
+
+ // If the max active entry is nonpositive and there exists inactive ones, find the first zero.
+ if (maxValue <= 0.0 && na < size) {
+ if (maxValue == 0.0) {
+ // If there exists an inactive entry before maxIdx, find it and return its index.
+ if (maxJ < maxIdx) {
+ var k = 0
+ while (k < maxJ && indices(k) == k) {
+ k += 1
+ }
+ maxIdx = k
+ }
+ } else {
+ // If the max active value is negative, find and return the first inactive index.
+ var k = 0
+ while (k < na && indices(k) == k) {
+ k += 1
+ }
+ maxIdx = k
+ }
+ }
+
+ maxIdx
+ }
+ }
}
object SparseVector {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 178d95a7b9..03be4119bd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -62,11 +62,50 @@ class VectorsSuite extends SparkFunSuite with Logging {
assert(vec.toArray.eq(arr))
}
+ test("dense argmax") {
+ val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
+ assert(vec.argmax === -1)
+
+ val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
+ assert(vec2.argmax === 3)
+
+ val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
+ assert(vec3.argmax === 3)
+ }
+
test("sparse to array") {
val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
assert(vec.toArray === arr)
}
+ test("sparse argmax") {
+ val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
+ assert(vec.argmax === -1)
+
+ val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
+ assert(vec2.argmax === 3)
+
+ val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7))
+ assert(vec3.argmax === 2)
+
+ // check for case that sparse vector is created with
+ // only negative values {0.0, 0.0,-1.0, -0.7, 0.0}
+ val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7))
+ assert(vec4.argmax === 0)
+
+ val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0))
+ assert(vec5.argmax === 1)
+
+ val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0))
+ assert(vec6.argmax === 2)
+
+ val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7))
+ assert(vec7.argmax === 1)
+
+ val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
+ assert(vec8.argmax === 0)
+ }
+
test("vector equals") {
val dv1 = Vectors.dense(arr.clone())
val dv2 = Vectors.dense(arr.clone())
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 36417f5df9..dd85254749 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -98,6 +98,10 @@ object MimaExcludes {
"org.apache.spark.api.r.StringRRDD.this"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.r.BaseRRDD.this")
+ ) ++ Seq(
+ // SPARK-7422 add argmax for sparse vectors
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Vector.argmax")
)
case v if v.startsWith("1.4") =>