aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-22 21:56:07 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-22 21:56:07 -0800
commitd9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8 (patch)
treed49602f3ebe461236ea5c27294deeadea610dbd5 /mllib/src/test/scala/org/apache
parentfc4b792d287095d70379a51f117c225d8d857078 (diff)
downloadspark-d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8.tar.gz
spark-d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8.tar.bz2
spark-d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8.zip
[SPARK-11912][ML] ml.feature.PCA minor refactor
Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is params and we should save it under ```metadata/``` rather than both under ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc``` but construct ```mllib.feature.PCAModel``` inside ```transform```. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9897 from yanboliang/spark-11912.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala31
1 file changed, 13 insertions, 18 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 5a21cd20ce..edab21e6c3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
test("params") {
ParamsSuite.checkParams(new PCA)
val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
- val model = new PCAModel("pca", new OldPCAModel(2, mat))
+ val model = new PCAModel("pca", mat)
ParamsSuite.checkParams(model)
}
@@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}
}
- test("read/write") {
+ test("PCA read/write") {
+ val t = new PCA()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setK(3)
+ testDefaultReadWrite(t)
+ }
- def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
- assert(model1.pc === model2.pc)
- }
- val allParams: Map[String, Any] = Map(
- "k" -> 3,
- "inputCol" -> "features",
- "outputCol" -> "pca_features"
- )
- val data = Seq(
- (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
- (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
- (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
- )
- val df = sqlContext.createDataFrame(data).toDF("id", "features")
- val pca = new PCA().setK(3)
- testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData)
+ test("PCAModel read/write") {
+ val instance = new PCAModel("myPCAModel",
+ Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix])
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.pc === instance.pc)
}
}