aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala6
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 305c3d187f..9cf722e121 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
/**
* Computes a [[PCAModel]] that contains the principal components of the input vectors.
*/
- override def fit(dataset: DataFrame): PCAModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): PCAModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v}
val pca = new feature.PCA(k = $(k))
@@ -124,7 +125,8 @@ class PCAModel private[ml] (
* NOTE: Vectors to be transformed must be the same length
* as the source vectors given to [[PCA.fit()]].
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val pcaModel = new feature.PCAModel($(k), pc, explainedVariance)
val pcaOp = udf { pcaModel.transform _ }