diff options
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala | 93 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala | 48 |
2 files changed, 141 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala new file mode 100644 index 0000000000..4e01e402b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.rdd.RDD + +/** + * A feature transformer that projects vectors to a low-dimensional space using PCA. + * + * @param k number of principal components + */ +class PCA(val k: Int) { + require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") + + /** + * Computes a [[PCAModel]] that contains the principal components of the input vectors. + * + * @param sources source vectors + */ + def fit(sources: RDD[Vector]): PCAModel = { + require(k <= sources.first().size, + s"source vector size is ${sources.first().size} must be greater than k=$k") + + val mat = new RowMatrix(sources) + val pc = mat.computePrincipalComponents(k) match { + case dm: DenseMatrix => + dm + case sm: SparseMatrix => + /* Convert a sparse matrix to dense. + * + * RowMatrix.computePrincipalComponents always returns a dense matrix. + * The following code is a safeguard. + */ + sm.toDense + case m => + throw new IllegalArgumentException("Unsupported matrix format. Expected " + + s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") + + } + new PCAModel(k, pc) + } + + /** Java-friendly version of [[fit()]] */ + def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd) +} + +/** + * Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + * + * @param k number of principal components. + * @param pc a principal components Matrix. Each column is one principal component. + */ +class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { + /** + * Transform a vector by computed Principal Components. + * + * @param vector vector to be transformed. + * Vector must be the same length as the source vectors given to [[PCA.fit()]]. + * @return transformed vector. Vector will be of length k. + */ + override def transform(vector: Vector): Vector = { + vector match { + case dv: DenseVector => + pc.transpose.multiply(dv) + case SparseVector(size, indices, values) => + /* SparseVector -> single row SparseMatrix */ + val sm = Matrices.sparse(size, 1, Array(0, indices.length), indices, values).transpose + val projection = sm.multiply(pc) + Vectors.dense(projection.values) + case _ => + throw new IllegalArgumentException("Unsupported vector format. Expected " + + s"SparseVector or DenseVector. Instead got: ${vector.getClass}") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala new file mode 100644 index 0000000000..758af588f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class PCASuite extends FunSuite with MLlibTestSparkContext { + + private val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + + private lazy val dataRDD = sc.parallelize(data, 2) + + test("Correct computing use a PCA wrapper") { + val k = dataRDD.count().toInt + val pca = new PCA(k).fit(dataRDD) + + val mat = new RowMatrix(dataRDD) + val pc = mat.computePrincipalComponents(k) + + val pca_transform = pca.transform(dataRDD).collect() + val mat_multiply = mat.multiply(pc).rows.collect() + + assert(pca_transform.toSet === mat_multiply.toSet) + } +} |