diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 7de5a4d5d3..e9df600c8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): MaxAbsScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) @@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // TODO: this looks hack, we may have to handle sparse and dense vectors separately. val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) |