diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 26ee8e1bf1..118a6e3e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - override def fit(dataset: DataFrame): StandardScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) @@ -135,7 +136,8 @@ class StandardScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } |