aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala6
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 26ee8e1bf1..118a6e3e6a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
- override def fit(dataset: DataFrame): StandardScalerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): StandardScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
@@ -135,7 +136,8 @@ class StandardScalerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
val scale = udf { scaler.transform _ }