aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorReza Zadeh <rizlar@gmail.com>2014-01-04 01:52:28 -0800
committerReza Zadeh <rizlar@gmail.com>2014-01-04 01:52:28 -0800
commit73daa700bd2acff7ff196c9262dffb2d8b9354bf (patch)
tree4bbdee9e875e2e447e36b7c21fd393bce97576c5 /mllib
parent26a74f0c4131d506384b94a913b8c6e1a30be9a4 (diff)
downloadspark-73daa700bd2acff7ff196c9262dffb2d8b9354bf.tar.gz
spark-73daa700bd2acff7ff196c9262dffb2d8b9354bf.tar.bz2
spark-73daa700bd2acff7ff196c9262dffb2d8b9354bf.zip
add k parameter
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala24
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala3
2 files changed, 13 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
index ac9178e78c..465fc746ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
@@ -43,9 +43,8 @@ object SVD {
* Then we compute U via easy matrix multiplication
* as U = A * V * S^-1
*
- * Only singular vectors associated with singular values
- * greater or equal to MIN_SVALUE are recovered. If there are k
- * such values, then the dimensions of the return will be:
+ * Only the k largest singular values and associated vectors are found.
+ * If there are k such values, then the dimensions of the return will be:
*
* S is k x k and diagonal, holding the singular values on diagonal
* U is m x k and satisfies U'U = eye(k)
@@ -57,22 +56,22 @@ object SVD {
* @param data RDD Matrix in sparse 1-index format ((int, int), value)
* @param m number of rows
* @param n number of columns
- * @param min_svalue Recover singular values greater or equal to min_svalue
+ * @param k Recover k singular values and vectors
* @return Three sparse matrices: U, S, V such that A = USV^T
*/
def sparseSVD(
data: RDD[MatrixEntry],
m: Int,
n: Int,
- min_svalue: Double)
+ k: Int)
: SVDecomposedMatrix =
{
if (m < n || m <= 0 || n <= 0) {
throw new IllegalArgumentException("Expecting a tall and skinny matrix")
}
- if (min_svalue < 1.0e-8) {
- throw new IllegalArgumentException("Minimum singular value requested is too small")
+ if (k < 1 || k > n) {
+ throw new IllegalArgumentException("Must request up to n singular values")
}
// Compute A^T A, assuming rows are sparse enough to fit in memory
@@ -93,12 +92,13 @@ object SVD {
// Since A^T A is small, we can compute its SVD directly
val svd = Singular.sparseSVD(ata)
val V = svd(0)
- val sigma = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x >= min_svalue)
+ val sigmas = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x > 1e-9)
- // threshold s values
- if(sigma.isEmpty) {
- throw new Exception("All singular values are smaller than min_svalue: " + min_svalue)
- }
+ if(sigmas.size < k) {
+ throw new Exception("Not enough singular values to return")
+ }
+
+ val sigma = sigmas.take(k)
val sc = data.sparkContext
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
index 71749ff729..dc4e9239a2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
@@ -66,9 +66,8 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
val n = 3
val data = sc.makeRDD(Array.tabulate(m,n){ (a,b)=>
MatrixEntry(a+1,b+1, (a+2).toDouble*(b+1)/(1+a+b)) }.flatten )
- val min_svalue = 1.0e-8
- val decomposed = SVD.sparseSVD(data, m, n, min_svalue)
+ val decomposed = SVD.sparseSVD(data, m, n, n)
val u = decomposed.U
val v = decomposed.V
val s = decomposed.S