diff options
author | lee19 <lee19@live.co.kr> | 2015-06-30 14:08:00 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-06-30 14:08:09 -0700 |
commit | e7dafd433eee089afcc0eeab02295deae90be7d3 (patch) | |
tree | bbc2c24391c1f81bea7d3a8acf78cdc2aa96beb4 | |
parent | 1b5439f6e809da7389993244f484692fb9ffb43f (diff) | |
download | spark-e7dafd433eee089afcc0eeab02295deae90be7d3.tar.gz spark-e7dafd433eee089afcc0eeab02295deae90be7d3.tar.bz2 spark-e7dafd433eee089afcc0eeab02295deae90be7d3.zip |
[SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
I'm sorry that I made https://github.com/apache/spark/pull/6949 closed by mistake.
I pushed codes again.
And, I added a test code.
>
There is a bug that `U.numCols() = self.nCols` in `IndexedRowMatrix.computeSVD()`
It should have been `U.numCols() = k = svd.U.numCols()`
>
```
self = U * sigma * V.transpose
(m x n) = (m x n) * (k x k) * (k x n) //ASIS
-->
(m x n) = (m x k) * (k x k) * (k x n) //TOBE
```
Author: lee19 <lee19@live.co.kr>
Closes #6953 from lee19/MLlibBugfix and squashes the following commits:
c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden.
4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error.
c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib]
8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
(cherry picked from commit e72526227fdcf93b7a33375ef954746ac08753f5)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala | 2 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala | 11 |
2 files changed, 12 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 3be530fa07..1c33b43ea7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -146,7 +146,7 @@ class IndexedRowMatrix( val indexedRows = indices.zip(svd.U.rows).map { case (i, v) => IndexedRow(i, v) } - new IndexedRowMatrix(indexedRows, nRows, nCols) + new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt) } else { null } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 4a7b99a976..0ecb7a221a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate matrix sizes of svd") { + val k = 2 + val A = new IndexedRowMatrix(indexedRows) + val svd = A.computeSVD(k, computeU = true) + assert(svd.U.numRows() === m) + assert(svd.U.numCols() === k) + assert(svd.s.size === k) + assert(svd.V.numRows === n) + assert(svd.V.numCols === k) + } + test("validate k in svd") { val A = new IndexedRowMatrix(indexedRows) intercept[IllegalArgumentException] { |