aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/mllib-guide.md5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala24
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala3
3 files changed, 16 insertions, 16 deletions
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index 711187fbea..abeb55d081 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -251,9 +251,10 @@ val data = sc.textFile("mllib/data/als/test.data").map { line =>
}
val m = 4
val n = 4
+val k = 1
-// recover singular vectors for singular values at or above 1e-5
-val (u, s, v) = SVD.sparseSVD(data, m, n, 1e-5)
+// recover largest singular vector
+val (u, s, v) = SVD.sparseSVD(data, m, n, 1)
println("singular values = " + s.toArray.mkString)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
index ac9178e78c..465fc746ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
@@ -43,9 +43,8 @@ object SVD {
* Then we compute U via easy matrix multiplication
* as U = A * V * S^-1
*
- * Only singular vectors associated with singular values
- * greater or equal to MIN_SVALUE are recovered. If there are k
- * such values, then the dimensions of the return will be:
+ * Only the k largest singular values and associated vectors are found.
+ * If there are k such values, then the dimensions of the return will be:
*
* S is k x k and diagonal, holding the singular values on diagonal
* U is m x k and satisfies U'U = eye(k)
@@ -57,22 +56,22 @@ object SVD {
* @param data RDD Matrix in sparse 1-index format ((int, int), value)
* @param m number of rows
* @param n number of columns
- * @param min_svalue Recover singular values greater or equal to min_svalue
+ * @param k Recover k singular values and vectors
* @return Three sparse matrices: U, S, V such that A = USV^T
*/
def sparseSVD(
data: RDD[MatrixEntry],
m: Int,
n: Int,
- min_svalue: Double)
+ k: Int)
: SVDecomposedMatrix =
{
if (m < n || m <= 0 || n <= 0) {
throw new IllegalArgumentException("Expecting a tall and skinny matrix")
}
- if (min_svalue < 1.0e-8) {
- throw new IllegalArgumentException("Minimum singular value requested is too small")
+ if (k < 1 || k > n) {
+ throw new IllegalArgumentException("Must request up to n singular values")
}
// Compute A^T A, assuming rows are sparse enough to fit in memory
@@ -93,12 +92,13 @@ object SVD {
// Since A^T A is small, we can compute its SVD directly
val svd = Singular.sparseSVD(ata)
val V = svd(0)
- val sigma = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x >= min_svalue)
+ val sigmas = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x > 1e-9)
- // threshold s values
- if(sigma.isEmpty) {
- throw new Exception("All singular values are smaller than min_svalue: " + min_svalue)
- }
+ if(sigmas.size < k) {
+ throw new Exception("Not enough singular values to return")
+ }
+
+ val sigma = sigmas.take(k)
val sc = data.sparkContext
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
index 71749ff729..dc4e9239a2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
@@ -66,9 +66,8 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
val n = 3
val data = sc.makeRDD(Array.tabulate(m,n){ (a,b)=>
MatrixEntry(a+1,b+1, (a+2).toDouble*(b+1)/(1+a+b)) }.flatten )
- val min_svalue = 1.0e-8
- val decomposed = SVD.sparseSVD(data, m, n, min_svalue)
+ val decomposed = SVD.sparseSVD(data, m, n, n)
val u = decomposed.U
val v = decomposed.V
val s = decomposed.S