author     Yanbo Liang <ybliang8@gmail.com>      2015-11-22 21:56:07 -0800
committer  Xiangrui Meng <meng@databricks.com>   2015-11-22 21:56:07 -0800
commit     d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8
tree       d49602f3ebe461236ea5c27294deeadea610dbd5
parent     fc4b792d287095d70379a51f117c225d8d857078
[SPARK-11912][ML] ml.feature.PCA minor refactor
Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is a param, so it should be saved only under ```metadata/``` rather than under both ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc```, and construct the ```mllib.feature.PCAModel``` inside ```transform``` instead.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9897 from yanboliang/spark-11912.
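For context, a minimal sketch of the unchanged public flow (assuming a Spark 1.6-era spark-shell where `sqlContext` is in scope): users still fit the `PCA` estimator and call `transform`; only the private `PCAModel` constructor and the persisted layout change, with the old `mllib.feature.PCAModel` rebuilt internally from `k` and `pc` during `transform`.

```scala
import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

// Toy input: a DataFrame with a vector column named "features".
val data = Seq(
  (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
  (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
  (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
)
val df = sqlContext.createDataFrame(data).toDF("id", "features")

val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pca_features")
  .setK(3)

// The fitted model now holds only the principal-components matrix (model.pc);
// k lives in the params/metadata rather than in the saved data.
val model = pca.fit(df)
model.transform(df).show()
```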
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala      | 23
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala | 31
2 files changed, 24 insertions(+), 30 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 32d7afee6e..aa88cb03d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
val pca = new feature.PCA(k = $(k))
val pcaModel = pca.fit(input)
- copyValues(new PCAModel(uid, pcaModel).setParent(this))
+ copyValues(new PCAModel(uid, pcaModel.pc).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] {
/**
* :: Experimental ::
* Model fitted by [[PCA]].
+ *
+ * @param pc A principal components Matrix. Each column is one principal component.
*/
@Experimental
class PCAModel private[ml] (
override val uid: String,
- pcaModel: feature.PCAModel)
+ val pc: DenseMatrix)
extends Model[PCAModel] with PCAParams with MLWritable {
import PCAModel._
- /** a principal components Matrix. Each column is one principal component. */
- val pc: DenseMatrix = pcaModel.pc
-
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -124,6 +123,7 @@ class PCAModel private[ml] (
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
+ val pcaModel = new feature.PCAModel($(k), pc)
val pcaOp = udf { pcaModel.transform _ }
dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
}
@@ -139,7 +139,7 @@ class PCAModel private[ml] (
}
override def copy(extra: ParamMap): PCAModel = {
- val copied = new PCAModel(uid, pcaModel)
+ val copied = new PCAModel(uid, pc)
copyValues(copied, extra).setParent(parent)
}
@@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] {
private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
- private case class Data(k: Int, pc: DenseMatrix)
+ private case class Data(pc: DenseMatrix)
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
- val data = Data(instance.getK, instance.pc)
+ val data = Data(instance.pc)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
@@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] {
override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
- .select("k", "pc")
+ val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
+ .select("pc")
.head()
- val oldModel = new feature.PCAModel(k, pc)
- val model = new PCAModel(metadata.uid, oldModel)
+ val model = new PCAModel(metadata.uid, pc)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 5a21cd20ce..edab21e6c3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
test("params") {
ParamsSuite.checkParams(new PCA)
val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
- val model = new PCAModel("pca", new OldPCAModel(2, mat))
+ val model = new PCAModel("pca", mat)
ParamsSuite.checkParams(model)
}
@@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}
}
- test("read/write") {
+ test("PCA read/write") {
+ val t = new PCA()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setK(3)
+ testDefaultReadWrite(t)
+ }
- def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
- assert(model1.pc === model2.pc)
- }
- val allParams: Map[String, Any] = Map(
- "k" -> 3,
- "inputCol" -> "features",
- "outputCol" -> "pca_features"
- )
- val data = Seq(
- (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
- (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
- (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
- )
- val df = sqlContext.createDataFrame(data).toDF("id", "features")
- val pca = new PCA().setK(3)
- testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData)
+ test("PCAModel read/write") {
+ val instance = new PCAModel("myPCAModel",
+ Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix])
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.pc === instance.pc)
}
}