aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2015-11-22 21:56:07 -0800
committer: Xiangrui Meng <meng@databricks.com> 2015-11-22 21:56:07 -0800
commit: d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8 (patch)
tree: d49602f3ebe461236ea5c27294deeadea610dbd5 /mllib/src/main/scala/org
parent: fc4b792d287095d70379a51f117c225d8d857078 (diff)
download: spark-d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8.tar.gz
spark-d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8.tar.bz2
spark-d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8.zip
[SPARK-11912][ML] ml.feature.PCA minor refactor
Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is a param, so we should save it only under ```metadata/``` rather than under both ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc```, and construct ```mllib.feature.PCAModel``` inside ```transform``` instead. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9897 from yanboliang/spark-11912.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala 23
1 files changed, 11 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 32d7afee6e..aa88cb03d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
val pca = new feature.PCA(k = $(k))
val pcaModel = pca.fit(input)
- copyValues(new PCAModel(uid, pcaModel).setParent(this))
+ copyValues(new PCAModel(uid, pcaModel.pc).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] {
/**
* :: Experimental ::
* Model fitted by [[PCA]].
+ *
+ * @param pc A principal components Matrix. Each column is one principal component.
*/
@Experimental
class PCAModel private[ml] (
override val uid: String,
- pcaModel: feature.PCAModel)
+ val pc: DenseMatrix)
extends Model[PCAModel] with PCAParams with MLWritable {
import PCAModel._
- /** a principal components Matrix. Each column is one principal component. */
- val pc: DenseMatrix = pcaModel.pc
-
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -124,6 +123,7 @@ class PCAModel private[ml] (
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
+ val pcaModel = new feature.PCAModel($(k), pc)
val pcaOp = udf { pcaModel.transform _ }
dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
}
@@ -139,7 +139,7 @@ class PCAModel private[ml] (
}
override def copy(extra: ParamMap): PCAModel = {
- val copied = new PCAModel(uid, pcaModel)
+ val copied = new PCAModel(uid, pc)
copyValues(copied, extra).setParent(parent)
}
@@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] {
private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
- private case class Data(k: Int, pc: DenseMatrix)
+ private case class Data(pc: DenseMatrix)
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
- val data = Data(instance.getK, instance.pc)
+ val data = Data(instance.pc)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
@@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] {
override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
- .select("k", "pc")
+ val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
+ .select("pc")
.head()
- val oldModel = new feature.PCAModel(k, pc)
- val model = new PCAModel(metadata.uid, oldModel)
+ val model = new PCAModel(metadata.uid, pc)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}