aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-04-08 23:01:15 -0700
committerPatrick Wendell <pwendell@gmail.com>2014-04-08 23:01:15 -0700
commit9689b663a2a4947ad60795321c770052f3c637f1 (patch)
treef2647f7b1ae3a3d11d3ecb29e764214b7cb589ca /mllib/src/main
parentfa0524fd02eedd0bbf1edc750dc3997a86ea25f5 (diff)
downloadspark-9689b663a2a4947ad60795321c770052f3c637f1.tar.gz
spark-9689b663a2a4947ad60795321c770052f3c637f1.tar.bz2
spark-9689b663a2a4947ad60795321c770052f3c637f1.zip
[SPARK-1390] Refactoring of matrices backed by RDDs
This is to refactor interfaces for matrices backed by RDDs. It would be better if we have a clear separation of local matrices and those backed by RDDs. Right now, we have 1. `org.apache.spark.mllib.linalg.SparseMatrix`, which is a wrapper over an RDD of matrix entries, i.e., coordinate list format. 2. `org.apache.spark.mllib.linalg.TallSkinnyDenseMatrix`, which is a wrapper over RDD[Array[Double]], i.e. row-oriented format. We will see naming collision when we introduce local `SparseMatrix`, and the name `TallSkinnyDenseMatrix` is not exact if we switch to `RDD[Vector]` from `RDD[Array[Double]]`. It would be better to have "RDD" in the class name to suggest that operations may trigger jobs. The proposed names are (all under `org.apache.spark.mllib.linalg.rdd`): 1. `RDDMatrix`: trait for matrices backed by one or more RDDs 2. `CoordinateRDDMatrix`: wrapper of `RDD[(Long, Long, Double)]` 3. `RowRDDMatrix`: wrapper of `RDD[Vector]` whose rows do not have special ordering 4. `IndexedRowRDDMatrix`: wrapper of `RDD[(Long, Vector)]` whose rows are associated with indices The current code also introduces local matrices. Author: Xiangrui Meng <meng@databricks.com> Closes #296 from mengxr/mat and squashes the following commits: 24d8294 [Xiangrui Meng] fix for groupBy returning Iterable bfc2b26 [Xiangrui Meng] merge master 8e4f1f5 [Xiangrui Meng] Merge branch 'master' into mat 0135193 [Xiangrui Meng] address Reza's comments 03cd7e1 [Xiangrui Meng] add pca/gram to IndexedRowMatrix add toBreeze to DistributedMatrix for test simplify tests b177ff1 [Xiangrui Meng] address Matei's comments be119fe [Xiangrui Meng] rename m/n to numRows/numCols for local matrix add tests for matrices b881506 [Xiangrui Meng] rename SparkPCA/SVD to TallSkinnyPCA/SVD e7d0d4a [Xiangrui Meng] move IndexedRDDMatrixRow to IndexedRowRDDMatrix 0d1491c [Xiangrui Meng] fix test errors a85262a [Xiangrui Meng] rename RDDMatrixRow to IndexedRDDMatrixRow b8b6ac3 [Xiangrui Meng] Remove old code 4cf679c [Xiangrui Meng] port pca to RowRDDMatrix, and add multiply and covariance 7836e2f [Xiangrui Meng] initial refactoring of matrices backed by RDDs
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala101
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixSVD.scala29
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/PCA.scala120
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala395
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixRow.scala)9
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/SparseMatrix.scala30
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyDenseMatrix.scala30
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyMatrixSVD.scala31
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala112
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixEntry.scala)24
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala148
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala344
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/LAUtils.scala67
13 files changed, 724 insertions, 716 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
new file mode 100644
index 0000000000..b11ba5d30f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import breeze.linalg.{Matrix => BM, DenseMatrix => BDM}
+
+/**
+ * Trait for a local matrix.
+ */
+trait Matrix extends Serializable {
+
+ /** Number of rows. */
+ def numRows: Int
+
+ /** Number of columns. */
+ def numCols: Int
+
+ /** Converts to a dense array in column major. */
+ def toArray: Array[Double]
+
+ /** Converts to a breeze matrix. */
+ private[mllib] def toBreeze: BM[Double]
+
+ /** Gets the (i, j)-th element. */
+ private[mllib] def apply(i: Int, j: Int): Double = toBreeze(i, j)
+
+ override def toString: String = toBreeze.toString()
+}
+
+/**
+ * Column-majored dense matrix.
+ * The entry values are stored in a single array of doubles with columns listed in sequence.
+ * For example, the following matrix
+ * {{{
+ * 1.0 2.0
+ * 3.0 4.0
+ * 5.0 6.0
+ * }}}
+ * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`.
+ *
+ * @param numRows number of rows
+ * @param numCols number of columns
+ * @param values matrix entries in column major
+ */
+class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix {
+
+ require(values.length == numRows * numCols)
+
+ override def toArray: Array[Double] = values
+
+ private[mllib] override def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values)
+}
+
+/**
+ * Factory methods for [[org.apache.spark.mllib.linalg.Matrix]].
+ */
+object Matrices {
+
+ /**
+ * Creates a column-majored dense matrix.
+ *
+ * @param numRows number of rows
+ * @param numCols number of columns
+ * @param values matrix entries in column major
+ */
+ def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = {
+ new DenseMatrix(numRows, numCols, values)
+ }
+
+ /**
+ * Creates a Matrix instance from a breeze matrix.
+ * @param breeze a breeze matrix
+ * @return a Matrix instance
+ */
+ private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = {
+ breeze match {
+ case dm: BDM[Double] =>
+ require(dm.majorStride == dm.rows,
+ "Do not support stride size different from the number of rows.")
+ new DenseMatrix(dm.rows, dm.cols, dm.data)
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Do not support conversion from type ${breeze.getClass.getName}.")
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixSVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixSVD.scala
deleted file mode 100644
index 319f82b449..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixSVD.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.linalg
-
-/**
- * Class that represents the SV decomposition of a matrix
- *
- * @param U such that A = USV^T
- * @param S such that A = USV^T
- * @param V such that A = USV^T
- */
-case class MatrixSVD(val U: SparseMatrix,
- val S: SparseMatrix,
- val V: SparseMatrix)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/PCA.scala
deleted file mode 100644
index fe5b3f6c7e..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/PCA.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.linalg
-
-import org.apache.spark.rdd.RDD
-
-
-import org.jblas.DoubleMatrix
-
-
-/**
- * Class used to obtain principal components
- */
-class PCA {
- private var k = 1
-
- /**
- * Set the number of top-k principle components to return
- */
- def setK(k: Int): PCA = {
- this.k = k
- this
- }
-
- /**
- * Compute PCA using the current set parameters
- */
- def compute(matrix: TallSkinnyDenseMatrix): Array[Array[Double]] = {
- computePCA(matrix)
- }
-
- /**
- * Compute PCA using the parameters currently set
- * See computePCA() for more details
- */
- def compute(matrix: RDD[Array[Double]]): Array[Array[Double]] = {
- computePCA(matrix)
- }
-
- /**
- * Computes the top k principal component coefficients for the m-by-n data matrix X.
- * Rows of X correspond to observations and columns correspond to variables.
- * The coefficient matrix is n-by-k. Each column of coeff contains coefficients
- * for one principal component, and the columns are in descending
- * order of component variance.
- * This function centers the data and uses the
- * singular value decomposition (SVD) algorithm.
- *
- * @param matrix dense matrix to perform PCA on
- * @return An nxk matrix with principal components in columns. Columns are inner arrays
- */
- private def computePCA(matrix: TallSkinnyDenseMatrix): Array[Array[Double]] = {
- val m = matrix.m
- val n = matrix.n
-
- if (m <= 0 || n <= 0) {
- throw new IllegalArgumentException("Expecting a well-formed matrix: m=$m n=$n")
- }
-
- computePCA(matrix.rows.map(_.data))
- }
-
- /**
- * Computes the top k principal component coefficients for the m-by-n data matrix X.
- * Rows of X correspond to observations and columns correspond to variables.
- * The coefficient matrix is n-by-k. Each column of coeff contains coefficients
- * for one principal component, and the columns are in descending
- * order of component variance.
- * This function centers the data and uses the
- * singular value decomposition (SVD) algorithm.
- *
- * @param matrix dense matrix to perform pca on
- * @return An nxk matrix of principal components
- */
- private def computePCA(matrix: RDD[Array[Double]]): Array[Array[Double]] = {
- val n = matrix.first.size
-
- // compute column sums and normalize matrix
- val colSumsTemp = matrix.map((_, 1)).fold((Array.ofDim[Double](n), 0)) {
- (a, b) =>
- val am = new DoubleMatrix(a._1)
- val bm = new DoubleMatrix(b._1)
- am.addi(bm)
- (a._1, a._2 + b._2)
- }
-
- val m = colSumsTemp._2
- val colSums = colSumsTemp._1.map(x => x / m)
-
- val data = matrix.map {
- x =>
- val row = Array.ofDim[Double](n)
- var i = 0
- while (i < n) {
- row(i) = x(i) - colSums(i)
- i += 1
- }
- row
- }
-
- val (u, s, v) = new SVD().setK(k).compute(data)
- v
- }
-}
-
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
deleted file mode 100644
index 0d97b7d92f..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
+++ /dev/null
@@ -1,395 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.linalg
-
-import org.apache.spark.SparkContext
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.RDD
-
-import org.jblas.{DoubleMatrix, Singular, MatrixFunctions}
-
-/**
- * Class used to obtain singular value decompositions
- */
-class SVD {
- private var k = 1
- private var computeU = true
-
- // All singular values smaller than rCond * sigma(0)
- // are treated as zero, where sigma(0) is the largest singular value.
- private var rCond = 1e-9
-
- /**
- * Set the number of top-k singular vectors to return
- */
- def setK(k: Int): SVD = {
- this.k = k
- this
- }
-
- /**
- * Sets the reciprocal condition number (rCond). All singular values
- * smaller than rCond * sigma(0) are treated as zero,
- * where sigma(0) is the largest singular value.
- */
- def setReciprocalConditionNumber(smallS: Double): SVD = {
- this.rCond = smallS
- this
- }
-
- /**
- * Should U be computed?
- */
- def setComputeU(compU: Boolean): SVD = {
- this.computeU = compU
- this
- }
-
- /**
- * Compute SVD using the current set parameters
- */
- def compute(matrix: TallSkinnyDenseMatrix): TallSkinnyMatrixSVD = {
- denseSVD(matrix)
- }
-
- /**
- * Compute SVD using the current set parameters
- * Returns (U, S, V) such that A = USV^T
- * U is a row-by-row dense matrix
- * S is a simple double array of singular values
- * V is a 2d array matrix
- * See [[denseSVD]] for more documentation
- */
- def compute(matrix: RDD[Array[Double]]):
- (RDD[Array[Double]], Array[Double], Array[Array[Double]]) = {
- denseSVD(matrix)
- }
-
- /**
- * See full paramter definition of sparseSVD for more description.
- *
- * @param matrix sparse matrix to factorize
- * @return Three sparse matrices: U, S, V such that A = USV^T
- */
- def compute(matrix: SparseMatrix): MatrixSVD = {
- sparseSVD(matrix)
- }
-
- /**
- * Singular Value Decomposition for Tall and Skinny matrices.
- * Given an m x n matrix A, this will compute matrices U, S, V such that
- * A = U * S * V'
- *
- * There is no restriction on m, but we require n^2 doubles to fit in memory.
- * Further, n should be less than m.
- *
- * The decomposition is computed by first computing A'A = V S^2 V',
- * computing svd locally on that (since n x n is small),
- * from which we recover S and V.
- * Then we compute U via easy matrix multiplication
- * as U = A * V * S^-1
- *
- * Only the k largest singular values and associated vectors are found.
- * If there are k such values, then the dimensions of the return will be:
- *
- * S is k x k and diagonal, holding the singular values on diagonal
- * U is m x k and satisfies U'U = eye(k)
- * V is n x k and satisfies V'V = eye(k)
- *
- * @param matrix dense matrix to factorize
- * @return See [[TallSkinnyMatrixSVD]] for the output matrices and arrays
- */
- private def denseSVD(matrix: TallSkinnyDenseMatrix): TallSkinnyMatrixSVD = {
- val m = matrix.m
- val n = matrix.n
-
- if (m < n || m <= 0 || n <= 0) {
- throw new IllegalArgumentException("Expecting a tall and skinny matrix m=$m n=$n")
- }
-
- if (k < 1 || k > n) {
- throw new IllegalArgumentException("Request up to n singular values n=$n k=$k")
- }
-
- val rowIndices = matrix.rows.map(_.i)
-
- // compute SVD
- val (u, sigma, v) = denseSVD(matrix.rows.map(_.data))
-
- if (computeU) {
- // prep u for returning
- val retU = TallSkinnyDenseMatrix(
- u.zip(rowIndices).map {
- case (row, i) => MatrixRow(i, row)
- },
- m,
- k)
-
- TallSkinnyMatrixSVD(retU, sigma, v)
- } else {
- TallSkinnyMatrixSVD(null, sigma, v)
- }
- }
-
- /**
- * Singular Value Decomposition for Tall and Skinny matrices.
- * Given an m x n matrix A, this will compute matrices U, S, V such that
- * A = U * S * V'
- *
- * There is no restriction on m, but we require n^2 doubles to fit in memory.
- * Further, n should be less than m.
- *
- * The decomposition is computed by first computing A'A = V S^2 V',
- * computing svd locally on that (since n x n is small),
- * from which we recover S and V.
- * Then we compute U via easy matrix multiplication
- * as U = A * V * S^-1
- *
- * Only the k largest singular values and associated vectors are found.
- * If there are k such values, then the dimensions of the return will be:
- *
- * S is k x k and diagonal, holding the singular values on diagonal
- * U is m x k and satisfies U'U = eye(k)
- * V is n x k and satisfies V'V = eye(k)
- *
- * The return values are as lean as possible: an RDD of rows for U,
- * a simple array for sigma, and a dense 2d matrix array for V
- *
- * @param matrix dense matrix to factorize
- * @return Three matrices: U, S, V such that A = USV^T
- */
- private def denseSVD(matrix: RDD[Array[Double]]):
- (RDD[Array[Double]], Array[Double], Array[Array[Double]]) = {
- val n = matrix.first.size
-
- if (k < 1 || k > n) {
- throw new IllegalArgumentException(
- "Request up to n singular values k=$k n=$n")
- }
-
- // Compute A^T A
- val fullata = matrix.mapPartitions {
- iter =>
- val localATA = Array.ofDim[Double](n, n)
- while (iter.hasNext) {
- val row = iter.next()
- var i = 0
- while (i < n) {
- var j = 0
- while (j < n) {
- localATA(i)(j) += row(i) * row(j)
- j += 1
- }
- i += 1
- }
- }
- Iterator(localATA)
- }.fold(Array.ofDim[Double](n, n)) {
- (a, b) =>
- var i = 0
- while (i < n) {
- var j = 0
- while (j < n) {
- a(i)(j) += b(i)(j)
- j += 1
- }
- i += 1
- }
- a
- }
-
- // Construct jblas A^T A locally
- val ata = new DoubleMatrix(fullata)
-
- // Since A^T A is small, we can compute its SVD directly
- val svd = Singular.sparseSVD(ata)
- val V = svd(0)
- val sigmas = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x / svd(1).get(0) > rCond)
-
- val sk = Math.min(k, sigmas.size)
- val sigma = sigmas.take(sk)
-
- // prepare V for returning
- val retV = Array.tabulate(n, sk)((i, j) => V.get(i, j))
-
- if (computeU) {
- // Compute U as U = A V S^-1
- // Compute VS^-1
- val vsinv = new DoubleMatrix(Array.tabulate(n, sk)((i, j) => V.get(i, j) / sigma(j)))
- val retU = matrix.map {
- x =>
- val v = new DoubleMatrix(Array(x))
- v.mmul(vsinv).data
- }
- (retU, sigma, retV)
- } else {
- (null, sigma, retV)
- }
- }
-
- /**
- * Singular Value Decomposition for Tall and Skinny sparse matrices.
- * Given an m x n matrix A, this will compute matrices U, S, V such that
- * A = U * S * V'
- *
- * There is no restriction on m, but we require O(n^2) doubles to fit in memory.
- * Further, n should be less than m.
- *
- * The decomposition is computed by first computing A'A = V S^2 V',
- * computing svd locally on that (since n x n is small),
- * from which we recover S and V.
- * Then we compute U via easy matrix multiplication
- * as U = A * V * S^-1
- *
- * Only the k largest singular values and associated vectors are found.
- * If there are k such values, then the dimensions of the return will be:
- *
- * S is k x k and diagonal, holding the singular values on diagonal
- * U is m x k and satisfies U'U = eye(k)
- * V is n x k and satisfies V'V = eye(k)
- *
- * All input and output is expected in sparse matrix format, 0-indexed
- * as tuples of the form ((i,j),value) all in RDDs using the
- * SparseMatrix class
- *
- * @param matrix sparse matrix to factorize
- * @return Three sparse matrices: U, S, V such that A = USV^T
- */
- private def sparseSVD(matrix: SparseMatrix): MatrixSVD = {
- val data = matrix.data
- val m = matrix.m
- val n = matrix.n
-
- if (m < n || m <= 0 || n <= 0) {
- throw new IllegalArgumentException("Expecting a tall and skinny matrix")
- }
-
- if (k < 1 || k > n) {
- throw new IllegalArgumentException("Must request up to n singular values")
- }
-
- // Compute A^T A, assuming rows are sparse enough to fit in memory
- val rows = data.map(entry =>
- (entry.i, (entry.j, entry.mval))).groupByKey()
- val emits = rows.flatMap {
- case (rowind, cols) =>
- cols.flatMap {
- case (colind1, mval1) =>
- cols.map {
- case (colind2, mval2) =>
- ((colind1, colind2), mval1 * mval2)
- }
- }
- }.reduceByKey(_ + _)
-
- // Construct jblas A^T A locally
- val ata = DoubleMatrix.zeros(n, n)
- for (entry <- emits.collect()) {
- ata.put(entry._1._1, entry._1._2, entry._2)
- }
-
- // Since A^T A is small, we can compute its SVD directly
- val svd = Singular.sparseSVD(ata)
- val V = svd(0)
- // This will be updated to rcond
- val sigmas = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x > 1e-9)
-
- if (sigmas.size < k) {
- throw new Exception("Not enough singular values to return k=" + k + " s=" + sigmas.size)
- }
-
- val sigma = sigmas.take(k)
-
- val sc = data.sparkContext
-
- // prepare V for returning
- val retVdata = sc.makeRDD(
- Array.tabulate(V.rows, sigma.length) {
- (i, j) =>
- MatrixEntry(i, j, V.get(i, j))
- }.flatten)
- val retV = SparseMatrix(retVdata, V.rows, sigma.length)
-
- val retSdata = sc.makeRDD(Array.tabulate(sigma.length) {
- x => MatrixEntry(x, x, sigma(x))
- })
-
- val retS = SparseMatrix(retSdata, sigma.length, sigma.length)
-
- // Compute U as U = A V S^-1
- // turn V S^-1 into an RDD as a sparse matrix
- val vsirdd = sc.makeRDD(Array.tabulate(V.rows, sigma.length) {
- (i, j) => ((i, j), V.get(i, j) / sigma(j))
- }.flatten)
-
- if (computeU) {
- // Multiply A by VS^-1
- val aCols = data.map(entry => (entry.j, (entry.i, entry.mval)))
- val bRows = vsirdd.map(entry => (entry._1._1, (entry._1._2, entry._2)))
- val retUdata = aCols.join(bRows).map {
- case (key, ((rowInd, rowVal), (colInd, colVal))) =>
- ((rowInd, colInd), rowVal * colVal)
- }.reduceByKey(_ + _).map {
- case ((row, col), mval) => MatrixEntry(row, col, mval)
- }
-
- val retU = SparseMatrix(retUdata, m, sigma.length)
- MatrixSVD(retU, retS, retV)
- } else {
- MatrixSVD(null, retS, retV)
- }
- }
-}
-
-/**
- * Top-level methods for calling sparse Singular Value Decomposition
- * NOTE: All matrices are 0-indexed
- */
-object SVD {
- def main(args: Array[String]) {
- if (args.length < 8) {
- println("Usage: SVD <master> <matrix_file> <m> <n> " +
- "<k> <output_U_file> <output_S_file> <output_V_file>")
- System.exit(1)
- }
-
- val (master, inputFile, m, n, k, output_u, output_s, output_v) =
- (args(0), args(1), args(2).toInt, args(3).toInt,
- args(4).toInt, args(5), args(6), args(7))
-
- val sc = new SparkContext(master, "SVD")
-
- val rawData = sc.textFile(inputFile)
- val data = rawData.map {
- line =>
- val parts = line.split(',')
- MatrixEntry(parts(0).toInt, parts(1).toInt, parts(2).toDouble)
- }
-
- val decomposed = new SVD().setK(k).compute(SparseMatrix(data, m, n))
- val u = decomposed.U.data
- val s = decomposed.S.data
- val v = decomposed.V.data
-
- println("Computed " + s.collect().length + " singular values and vectors")
- u.saveAsTextFile(output_u)
- s.saveAsTextFile(output_s)
- v.saveAsTextFile(output_v)
- System.exit(0)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixRow.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
index 2608a67bfe..46b1054574 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixRow.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
@@ -17,10 +17,5 @@
package org.apache.spark.mllib.linalg
-/**
- * Class that represents a row of a dense matrix
- *
- * @param i row index (0 indexing used)
- * @param data entries of the row
- */
-case class MatrixRow(val i: Int, val data: Array[Double])
+/** Represents singular value decomposition (SVD) factors. */
+case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SparseMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SparseMatrix.scala
deleted file mode 100644
index cbd1a2a5a4..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SparseMatrix.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.linalg
-
-import org.apache.spark.rdd.RDD
-
-
-/**
- * Class that represents a sparse matrix
- *
- * @param data RDD of nonzero entries
- * @param m number of rows
- * @param n numner of columns
- */
-case class SparseMatrix(val data: RDD[MatrixEntry], val m: Int, val n: Int)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyDenseMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyDenseMatrix.scala
deleted file mode 100644
index e4ef3c58e8..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyDenseMatrix.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.linalg
-
-import org.apache.spark.rdd.RDD
-
-
-/**
- * Class that represents a dense matrix
- *
- * @param rows RDD of rows
- * @param m number of rows
- * @param n number of columns
- */
-case class TallSkinnyDenseMatrix(val rows: RDD[MatrixRow], val m: Int, val n: Int)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyMatrixSVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyMatrixSVD.scala
deleted file mode 100644
index b3a450e923..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyMatrixSVD.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.linalg
-
-/**
- * Class that represents the singular value decomposition of a matrix
- *
- * @param U such that A = USV^T is a TallSkinnyDenseMatrix
- * @param S such that A = USV^T is a simple double array
- * @param V such that A = USV^T, V is a 2d array matrix that holds
- * singular vectors in columns. Columns are inner arrays
- * i.e. V(i)(j) is standard math notation V_{ij}
- */
-case class TallSkinnyMatrixSVD(val U: TallSkinnyDenseMatrix,
- val S: Array[Double],
- val V: Array[Array[Double]])
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
new file mode 100644
index 0000000000..9194f65749
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg.distributed
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vectors
+
+/**
+ * Represents an entry in an distributed matrix.
+ * @param i row index
+ * @param j column index
+ * @param value value of the entry
+ */
+case class MatrixEntry(i: Long, j: Long, value: Double)
+
+/**
+ * Represents a matrix in coordinate format.
+ *
+ * @param entries matrix entries
+ * @param nRows number of rows. A non-positive value means unknown, and then the number of rows will
+ * be determined by the max row index plus one.
+ * @param nCols number of columns. A non-positive value means unknown, and then the number of
+ * columns will be determined by the max column index plus one.
+ */
+class CoordinateMatrix(
+ val entries: RDD[MatrixEntry],
+ private var nRows: Long,
+ private var nCols: Long) extends DistributedMatrix {
+
+ /** Alternative constructor leaving matrix dimensions to be determined automatically. */
+ def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L)
+
+ /** Gets or computes the number of columns. */
+ override def numCols(): Long = {
+ if (nCols <= 0L) {
+ computeSize()
+ }
+ nCols
+ }
+
+ /** Gets or computes the number of rows. */
+ override def numRows(): Long = {
+ if (nRows <= 0L) {
+ computeSize()
+ }
+ nRows
+ }
+
+ /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */
+ def toIndexedRowMatrix(): IndexedRowMatrix = {
+ val nl = numCols()
+ if (nl > Int.MaxValue) {
+ sys.error(s"Cannot convert to a row-oriented format because the number of columns $nl is " +
+ "too large.")
+ }
+ val n = nl.toInt
+ val indexedRows = entries.map(entry => (entry.i, (entry.j.toInt, entry.value)))
+ .groupByKey()
+ .map { case (i, vectorEntries) =>
+ IndexedRow(i, Vectors.sparse(n, vectorEntries.toSeq))
+ }
+ new IndexedRowMatrix(indexedRows, numRows(), n)
+ }
+
+ /**
+ * Converts to RowMatrix, dropping row indices after grouping by row index.
+ * The number of columns must be within the integer range.
+ */
+ def toRowMatrix(): RowMatrix = {
+ toIndexedRowMatrix().toRowMatrix()
+ }
+
+ /** Determines the size by computing the max row/column index. */
+ private def computeSize() {
+ // Reduce will throw an exception if `entries` is empty.
+ val (m1, n1) = entries.map(entry => (entry.i, entry.j)).reduce { case ((i1, j1), (i2, j2)) =>
+ (math.max(i1, i2), math.max(j1, j2))
+ }
+ // There may be empty columns at the very right and empty rows at the very bottom.
+ nRows = math.max(nRows, m1 + 1L)
+ nCols = math.max(nCols, n1 + 1L)
+ }
+
+ /** Collects data and assembles a local matrix. */
+ private[mllib] override def toBreeze(): BDM[Double] = {
+ val m = numRows().toInt
+ val n = numCols().toInt
+ val mat = BDM.zeros[Double](m, n)
+ entries.collect().foreach { case MatrixEntry(i, j, value) =>
+ mat(i.toInt, j.toInt) = value
+ }
+ mat
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixEntry.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
index 416996fcbe..13f72a3c72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixEntry.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
@@ -15,13 +15,23 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.linalg
+package org.apache.spark.mllib.linalg.distributed
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.mllib.linalg.Matrix
/**
- * Class that represents an entry in a sparse matrix of doubles.
- *
- * @param i row index (0 indexing used)
- * @param j column index (0 indexing used)
- * @param mval value of entry in matrix
+ * Represents a distributively stored matrix backed by one or more RDDs.
*/
-case class MatrixEntry(val i: Int, val j: Int, val mval: Double)
+trait DistributedMatrix extends Serializable {
+
+ /** Gets or computes the number of rows. */
+ def numRows(): Long
+
+ /** Gets or computes the number of columns. */
+ def numCols(): Long
+
+ /** Collects data and assembles a local dense breeze matrix (for test only). */
+ private[mllib] def toBreeze(): BDM[Double]
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
new file mode 100644
index 0000000000..e110f070bd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg.distributed
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.SingularValueDecomposition
+
+/** Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. */
+case class IndexedRow(index: Long, vector: Vector)
+
+/**
+ * Represents a row-oriented [[org.apache.spark.mllib.linalg.distributed.DistributedMatrix]] with
+ * indexed rows.
+ *
+ * @param rows indexed rows of this matrix
+ * @param nRows number of rows. A non-positive value means unknown, and then the number of rows will
+ * be determined by the max row index plus one.
+ * @param nCols number of columns. A non-positive value means unknown, and then the number of
+ * columns will be determined by the size of the first row.
+ */
+class IndexedRowMatrix(
+ val rows: RDD[IndexedRow],
+ private var nRows: Long,
+ private var nCols: Int) extends DistributedMatrix {
+
+ /** Alternative constructor leaving matrix dimensions to be determined automatically. */
+ def this(rows: RDD[IndexedRow]) = this(rows, 0L, 0)
+
+ override def numCols(): Long = {
+ if (nCols <= 0) {
+ // Calling `first` will throw an exception if `rows` is empty.
+ nCols = rows.first().vector.size
+ }
+ nCols
+ }
+
+ override def numRows(): Long = {
+ if (nRows <= 0L) {
+ // Reduce will throw an exception if `rows` is empty.
+ nRows = rows.map(_.index).reduce(math.max) + 1L
+ }
+ nRows
+ }
+
+ /**
+ * Drops row indices and converts this matrix to a
+ * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]].
+ */
+ def toRowMatrix(): RowMatrix = {
+ new RowMatrix(rows.map(_.vector), 0L, nCols)
+ }
+
+ /**
+ * Computes the singular value decomposition of this matrix.
+ * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'.
+ *
+ * There is no restriction on m, but we require `n^2` doubles to fit in memory.
+ * Further, n should be less than m.
+
+ * The decomposition is computed by first computing A'A = V S^2 V',
+ * computing svd locally on that (since n x n is small), from which we recover S and V.
+ * Then we compute U via easy matrix multiplication as U = A * (V * S^-1).
+ * Note that this approach requires `O(n^3)` time on the master node.
+ *
+ * At most k largest non-zero singular values and associated vectors are returned.
+ * If there are k such values, then the dimensions of the return will be:
+ *
+ * U is an [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]] of size m x k that
+ * satisfies U'U = eye(k),
+ * s is a Vector of size k, holding the singular values in descending order,
+ * and V is a local Matrix of size n x k that satisfies V'V = eye(k).
+ *
+ * @param k number of singular values to keep. We might return less than k if there are
+ * numerically zero singular values. See rCond.
+ * @param computeU whether to compute U
+ * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0)
+ * are treated as zero, where sigma(0) is the largest singular value.
+ * @return SingularValueDecomposition(U, s, V)
+ */
+ def computeSVD(
+ k: Int,
+ computeU: Boolean = false,
+ rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = {
+ val indices = rows.map(_.index)
+ val svd = toRowMatrix().computeSVD(k, computeU, rCond)
+ val U = if (computeU) {
+ val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
+ IndexedRow(i, v)
+ }
+ new IndexedRowMatrix(indexedRows, nRows, nCols)
+ } else {
+ null
+ }
+ SingularValueDecomposition(U, svd.s, svd.V)
+ }
+
+ /**
+ * Multiply this matrix by a local matrix on the right.
+ *
+ * @param B a local matrix whose number of rows must match the number of columns of this matrix
+ * @return an IndexedRowMatrix representing the product, which preserves partitioning
+ */
+ def multiply(B: Matrix): IndexedRowMatrix = {
+ val mat = toRowMatrix().multiply(B)
+ val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) =>
+ IndexedRow(i, v)
+ }
+ new IndexedRowMatrix(indexedRows, nRows, nCols)
+ }
+
+ /**
+ * Computes the Gramian matrix `A^T A`.
+ */
+ def computeGramianMatrix(): Matrix = {
+ toRowMatrix().computeGramianMatrix()
+ }
+
+ private[mllib] override def toBreeze(): BDM[Double] = {
+ val m = numRows().toInt
+ val n = numCols().toInt
+ val mat = BDM.zeros[Double](m, n)
+ rows.collect().foreach { case IndexedRow(rowIndex, vector) =>
+ val i = rowIndex.toInt
+ vector.toBreeze.activeIterator.foreach { case (j, v) =>
+ mat(i, j) = v
+ }
+ }
+ mat
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
new file mode 100644
index 0000000000..f59811f18a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg.distributed
+
+import java.util
+
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
+import breeze.numerics.{sqrt => brzSqrt}
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.Logging
+
+/**
+ * Represents a row-oriented distributed Matrix with no meaningful row indices.
+ *
+ * @param rows rows stored as an RDD[Vector]
+ * @param nRows number of rows. A non-positive value means unknown, and then the number of rows will
+ * be determined by the number of records in the RDD `rows`.
+ * @param nCols number of columns. A non-positive value means unknown, and then the number of
+ * columns will be determined by the size of the first row.
+ */
+class RowMatrix(
+ val rows: RDD[Vector],
+ private var nRows: Long,
+ private var nCols: Int) extends DistributedMatrix with Logging {
+
+ /** Alternative constructor leaving matrix dimensions to be determined automatically. */
+ def this(rows: RDD[Vector]) = this(rows, 0L, 0)
+
+ /** Gets or computes the number of columns. */
+ override def numCols(): Long = {
+ if (nCols <= 0) {
+ // Calling `first` will throw an exception if `rows` is empty.
+ nCols = rows.first().size
+ }
+ nCols
+ }
+
+ /** Gets or computes the number of rows. */
+ override def numRows(): Long = {
+ if (nRows <= 0L) {
+ nRows = rows.count()
+ if (nRows == 0L) {
+ sys.error("Cannot determine the number of rows because it is not specified in the " +
+ "constructor and the rows RDD is empty.")
+ }
+ }
+ nRows
+ }
+
+ /**
+ * Computes the Gramian matrix `A^T A`.
+ */
+ def computeGramianMatrix(): Matrix = {
+ val n = numCols().toInt
+ val nt: Int = n * (n + 1) / 2
+
+ // Compute the upper triangular part of the gram matrix.
+ val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))(
+ seqOp = (U, v) => {
+ RowMatrix.dspr(1.0, v, U.data)
+ U
+ },
+ combOp = (U1, U2) => U1 += U2
+ )
+
+ RowMatrix.triuToFull(n, GU.data)
+ }
+
+ /**
+ * Computes the singular value decomposition of this matrix.
+ * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'.
+ *
+ * There is no restriction on m, but we require `n^2` doubles to fit in memory.
+ * Further, n should be less than m.
+
+ * The decomposition is computed by first computing A'A = V S^2 V',
+ * computing svd locally on that (since n x n is small), from which we recover S and V.
+ * Then we compute U via easy matrix multiplication as U = A * (V * S^-1).
+ * Note that this approach requires `O(n^3)` time on the master node.
+ *
+ * At most k largest non-zero singular values and associated vectors are returned.
+ * If there are k such values, then the dimensions of the return will be:
+ *
+ * U is a RowMatrix of size m x k that satisfies U'U = eye(k),
+ * s is a Vector of size k, holding the singular values in descending order,
+ * and V is a Matrix of size n x k that satisfies V'V = eye(k).
+ *
+ * @param k number of singular values to keep. We might return less than k if there are
+ * numerically zero singular values. See rCond.
+ * @param computeU whether to compute U
+ * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0)
+ * are treated as zero, where sigma(0) is the largest singular value.
+ * @return SingularValueDecomposition(U, s, V)
+ */
+ def computeSVD(
+ k: Int,
+ computeU: Boolean = false,
+ rCond: Double = 1e-9): SingularValueDecomposition[RowMatrix, Matrix] = {
+ val n = numCols().toInt
+ require(k > 0 && k <= n, s"Request up to n singular values k=$k n=$n.")
+
+ val G = computeGramianMatrix()
+
+ // TODO: Use sparse SVD instead.
+ val (u: BDM[Double], sigmaSquares: BDV[Double], v: BDM[Double]) =
+ brzSvd(G.toBreeze.asInstanceOf[BDM[Double]])
+ val sigmas: BDV[Double] = brzSqrt(sigmaSquares)
+
+ // Determine effective rank.
+ val sigma0 = sigmas(0)
+ val threshold = rCond * sigma0
+ var i = 0
+ while (i < k && sigmas(i) >= threshold) {
+ i += 1
+ }
+ val sk = i
+
+ if (sk < k) {
+ logWarning(s"Requested $k singular values but only found $sk nonzeros.")
+ }
+
+ val s = Vectors.dense(util.Arrays.copyOfRange(sigmas.data, 0, sk))
+ val V = Matrices.dense(n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk))
+
+ if (computeU) {
+ // N = Vk * Sk^{-1}
+ val N = new BDM[Double](n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk))
+ var i = 0
+ var j = 0
+ while (j < sk) {
+ i = 0
+ val sigma = sigmas(j)
+ while (i < n) {
+ N(i, j) /= sigma
+ i += 1
+ }
+ j += 1
+ }
+ val U = this.multiply(Matrices.fromBreeze(N))
+ SingularValueDecomposition(U, s, V)
+ } else {
+ SingularValueDecomposition(null, s, V)
+ }
+ }
+
+ /**
+ * Computes the covariance matrix, treating each row as an observation.
+ * @return a local dense matrix of size n x n
+ */
+ def computeCovariance(): Matrix = {
+ val n = numCols().toInt
+
+ if (n > 10000) {
+ val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE
+ logWarning(s"The number of columns $n is greater than 10000! " +
+ s"We need at least $mem bytes of memory.")
+ }
+
+ val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
+ seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
+ combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
+ )
+
+ // Update _m if it is not set, or verify its value.
+ if (nRows <= 0L) {
+ nRows = m
+ } else {
+ require(nRows == m,
+ s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
+ }
+
+ mean :/= m.toDouble
+
+ // We use the formula Cov(X, Y) = E[X * Y] - E[X] E[Y], which is not accurate if E[X * Y] is
+ // large but Cov(X, Y) is small, but it is good for sparse computation.
+ // TODO: find a fast and stable way for sparse data.
+
+ val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]]
+
+ var i = 0
+ var j = 0
+ val m1 = m - 1.0
+ var alpha = 0.0
+ while (i < n) {
+ alpha = m / m1 * mean(i)
+ j = 0
+ while (j < n) {
+ G(i, j) = G(i, j) / m1 - alpha * mean(j)
+ j += 1
+ }
+ i += 1
+ }
+
+ Matrices.fromBreeze(G)
+ }
+
+ /**
+ * Computes the top k principal components.
+ * Rows correspond to observations and columns correspond to variables.
+ * The principal components are stored a local matrix of size n-by-k.
+ * Each column corresponds for one principal component,
+ * and the columns are in descending order of component variance.
+ *
+ * @param k number of top principal components.
+ * @return a matrix of size n-by-k, whose columns are principal components
+ */
+ def computePrincipalComponents(k: Int): Matrix = {
+ val n = numCols().toInt
+ require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]")
+
+ val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]]
+
+ val (u: BDM[Double], _, _) = brzSvd(Cov)
+
+ if (k == n) {
+ Matrices.dense(n, k, u.data)
+ } else {
+ Matrices.dense(n, k, util.Arrays.copyOfRange(u.data, 0, n * k))
+ }
+ }
+
+ /**
+ * Multiply this matrix by a local matrix on the right.
+ *
+ * @param B a local matrix whose number of rows must match the number of columns of this matrix
+ * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product,
+ * which preserves partitioning
+ */
+ def multiply(B: Matrix): RowMatrix = {
+ val n = numCols().toInt
+ require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}")
+
+ require(B.isInstanceOf[DenseMatrix],
+ s"Only support dense matrix at this time but found ${B.getClass.getName}.")
+
+ val Bb = rows.context.broadcast(B)
+ val AB = rows.mapPartitions({ iter =>
+ val Bi = Bb.value.toBreeze.asInstanceOf[BDM[Double]]
+ iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze))
+ }, preservesPartitioning = true)
+
+ new RowMatrix(AB, nRows, B.numCols)
+ }
+
+ private[mllib] override def toBreeze(): BDM[Double] = {
+ val m = numRows().toInt
+ val n = numCols().toInt
+ val mat = BDM.zeros[Double](m, n)
+ var i = 0
+ rows.collect().foreach { v =>
+ v.toBreeze.activeIterator.foreach { case (j, v) =>
+ mat(i, j) = v
+ }
+ i += 1
+ }
+ mat
+ }
+}
+
+object RowMatrix {
+
+ /**
+ * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR.
+ *
+ * @param U the upper triangular part of the matrix packed in an array (column major)
+ */
+ private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
+ // TODO: Find a better home (breeze?) for this method.
+ val n = v.size
+ v match {
+ case dv: DenseVector =>
+ blas.dspr("U", n, 1.0, dv.values, 1, U)
+ case sv: SparseVector =>
+ val indices = sv.indices
+ val values = sv.values
+ val nnz = indices.length
+ var colStartIdx = 0
+ var prevCol = 0
+ var col = 0
+ var j = 0
+ var i = 0
+ var av = 0.0
+ while (j < nnz) {
+ col = indices(j)
+ // Skip empty columns.
+ colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2
+ col = indices(j)
+ av = alpha * values(j)
+ i = 0
+ while (i <= j) {
+ U(colStartIdx + indices(i)) += av * values(i)
+ i += 1
+ }
+ j += 1
+ prevCol = col
+ }
+ }
+ }
+
+ /**
+ * Fills a full square matrix from its upper triangular part.
+ */
+ private def triuToFull(n: Int, U: Array[Double]): Matrix = {
+ val G = new BDM[Double](n, n)
+
+ var row = 0
+ var col = 0
+ var idx = 0
+ var value = 0.0
+ while (col < n) {
+ row = 0
+ while (row < col) {
+ value = U(idx)
+ G(row, col) = value
+ G(col, row) = value
+ idx += 1
+ row += 1
+ }
+ G(col, col) = U(idx)
+ idx += 1
+ col +=1
+ }
+
+ Matrices.dense(n, n, G.data)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LAUtils.scala
deleted file mode 100644
index 87aac34757..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LAUtils.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.util
-
-import org.apache.spark.SparkContext._
-
-import org.apache.spark.mllib.linalg._
-
-/**
- * Helper methods for linear algebra
- */
-object LAUtils {
- /**
- * Convert a SparseMatrix into a TallSkinnyDenseMatrix
- *
- * @param sp Sparse matrix to be converted
- * @return dense version of the input
- */
- def sparseToTallSkinnyDense(sp: SparseMatrix): TallSkinnyDenseMatrix = {
- val m = sp.m
- val n = sp.n
- val rows = sp.data.map(x => (x.i, (x.j, x.mval))).groupByKey().map {
- case (i, cols) =>
- val rowArray = Array.ofDim[Double](n)
- var j = 0
- val colsItr = cols.iterator
- while (colsItr.hasNext) {
- val element = colsItr.next
- rowArray(element._1) = element._2
- j += 1
- }
- MatrixRow(i, rowArray)
- }
- TallSkinnyDenseMatrix(rows, m, n)
- }
-
- /**
- * Convert a TallSkinnyDenseMatrix to a SparseMatrix
- *
- * @param a matrix to be converted
- * @return sparse version of the input
- */
- def denseToSparse(a: TallSkinnyDenseMatrix): SparseMatrix = {
- val m = a.m
- val n = a.n
- val data = a.rows.flatMap {
- mrow => Array.tabulate(n)(j => MatrixEntry(mrow.i, j, mrow.data(j)))
- .filter(x => x.mval != 0)
- }
- SparseMatrix(data, m, n)
- }
-}