aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorGeorge Dittmar <georgedittmar@gmail.com>2015-07-20 08:55:37 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-20 08:55:37 -0700
commit3f7de7db4cf7c5e2824cb91087c5e9d4beb0f738 (patch)
treeeb9dacf87b83b2e1524f45c25f9c68dbb5b3f13a /mllib/src/main
parent79ec07290d0b4d16f1643af83824d926304c8f46 (diff)
downloadspark-3f7de7db4cf7c5e2824cb91087c5e9d4beb0f738.tar.gz
spark-3f7de7db4cf7c5e2824cb91087c5e9d4beb0f738.tar.bz2
spark-3f7de7db4cf7c5e2824cb91087c5e9d4beb0f738.zip
[SPARK-7422] [MLLIB] Add argmax to Vector, SparseVector
Modifying Vector, DenseVector, and SparseVector to implement argmax functionality. This work is to set the stage for changes to be done in Spark-7423. Author: George Dittmar <georgedittmar@gmail.com> Author: George <dittmar@Georges-MacBook-Pro.local> Author: dittmarg <george.dittmar@webtrends.com> Author: Xiangrui Meng <meng@databricks.com> Closes #6112 from GeorgeDittmar/SPARK-7422 and squashes the following commits: 3e0a939 [George Dittmar] Merge pull request #1 from mengxr/SPARK-7422 127dec5 [Xiangrui Meng] update argmax impl 2ea6a55 [George Dittmar] Added MimaExcludes for Vectors.argmax 98058f4 [George Dittmar] Merge branch 'master' of github.com:apache/spark into SPARK-7422 5fd9380 [George Dittmar] fixing style check error 42341fb [George Dittmar] refactoring arg max check to better handle zero values b22af46 [George Dittmar] Fixing spaces between commas in unit test f2eba2f [George Dittmar] Cleaning up unit tests to be fewer lines aa330e3 [George Dittmar] Fixing some last if else spacing issues ac53c55 [George Dittmar] changing dense vector argmax unit test to be one line call vs 2 d5b5423 [George Dittmar] Fixing code style and updating if logic on when to check for zero values ee1a85a [George Dittmar] Cleaning up unit tests a bit and modifying a few cases 3ee8711 [George Dittmar] Fixing corner case issue with zeros in the active values of the sparse vector. Updated unit tests b1f059f [George Dittmar] Added comment before we start arg max calculation. Updated unit tests to cover corner cases f21dcce [George Dittmar] commit af17981 [dittmarg] Initial work fixing bug that was made clear in pr eeda560 [George] Fixing SparseVector argmax function to ignore zero values while doing the calculation. 4526acc [George] Merge branch 'master' of github.com:apache/spark into SPARK-7422 df9538a [George] Added argmax to sparse vector and added unit test 3cffed4 [George] Adding unit tests for argmax functions for Dense and Sparse vectors 04677af [George] initial work on adding argmax to Vector and SparseVector
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala57
1 files changed, 52 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index e048b01d92..9067b3ba9a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
toDense
}
}
+
+ /**
+ * Find the index of a maximal element. Returns the first maximal element in case of a tie.
+ * Returns -1 if vector has length 0.
+ */
+ def argmax: Int
}
/**
@@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
new SparseVector(size, ii, vv)
}
- /**
- * Find the index of a maximal element. Returns the first maximal element in case of a tie.
- * Returns -1 if vector has length 0.
- */
- private[spark] def argmax: Int = {
+ override def argmax: Int = {
if (size == 0) {
-1
} else {
@@ -717,6 +719,51 @@ class SparseVector(
new SparseVector(size, ii, vv)
}
}
+
+ override def argmax: Int = {
+ if (size == 0) {
+ -1
+ } else {
+ // Find the max active entry.
+ var maxIdx = indices(0)
+ var maxValue = values(0)
+ var maxJ = 0
+ var j = 1
+ val na = numActives
+ while (j < na) {
+ val v = values(j)
+ if (v > maxValue) {
+ maxValue = v
+ maxIdx = indices(j)
+ maxJ = j
+ }
+ j += 1
+ }
+
+ // If the max active entry is nonpositive and there exists inactive ones, find the first zero.
+ if (maxValue <= 0.0 && na < size) {
+ if (maxValue == 0.0) {
+ // If there exists an inactive entry before maxIdx, find it and return its index.
+ if (maxJ < maxIdx) {
+ var k = 0
+ while (k < maxJ && indices(k) == k) {
+ k += 1
+ }
+ maxIdx = k
+ }
+ } else {
+ // If the max active value is negative, find and return the first inactive index.
+ var k = 0
+ while (k < na && indices(k) == k) {
+ k += 1
+ }
+ maxIdx = k
+ }
+ }
+
+ maxIdx
+ }
+ }
}
object SparseVector {