aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorlee19 <lee19@live.co.kr>2015-06-30 14:08:00 -0700
committerXiangrui Meng <meng@databricks.com>2015-06-30 14:08:00 -0700
commite72526227fdcf93b7a33375ef954746ac08753f5 (patch)
tree7d34d9b23897c088254cd83c753b7b0ed0ded3f6 /mllib
parent8c898964f095fcb5bb1c9212e1e484b1eb55c296 (diff)
downloadspark-e72526227fdcf93b7a33375ef954746ac08753f5.tar.gz
spark-e72526227fdcf93b7a33375ef954746ac08753f5.tar.bz2
spark-e72526227fdcf93b7a33375ef954746ac08753f5.zip
[SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
I'm sorry that I made https://github.com/apache/spark/pull/6949 closed by mistake. I pushed codes again. And, I added a test code. > There is a bug that `U.numCols() = self.nCols` in `IndexedRowMatrix.computeSVD()` It should have been `U.numCols() = k = svd.U.numCols()` > ``` self = U * sigma * V.transpose (m x n) = (m x n) * (k x k) * (k x n) //ASIS --> (m x n) = (m x k) * (k x k) * (k x n) //TOBE ``` Author: lee19 <lee19@live.co.kr> Closes #6953 from lee19/MLlibBugfix and squashes the following commits: c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden. 4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error. c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib] 8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala11
2 files changed, 12 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index 3be530fa07..1c33b43ea7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -146,7 +146,7 @@ class IndexedRowMatrix(
val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
IndexedRow(i, v)
}
- new IndexedRowMatrix(indexedRows, nRows, nCols)
+ new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
} else {
null
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 4a7b99a976..0ecb7a221a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(closeToZero(U * brzDiag(s) * V.t - localA))
}
+ test("validate matrix sizes of svd") {
+ val k = 2
+ val A = new IndexedRowMatrix(indexedRows)
+ val svd = A.computeSVD(k, computeU = true)
+ assert(svd.U.numRows() === m)
+ assert(svd.U.numCols() === k)
+ assert(svd.s.size === k)
+ assert(svd.V.numRows === n)
+ assert(svd.V.numCols === k)
+ }
+
test("validate k in svd") {
val A = new IndexedRowMatrix(indexedRows)
intercept[IllegalArgumentException] {