diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index f36cf503a0..5075b78c98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - override def fit(dataset: DataFrame): IDFModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) @@ -115,7 +116,8 @@ class IDFModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val idf = udf { vec: Vector => idfModel.transform(vec) } dataset.withColumn($(outputCol), idf(col($(inputCol)))) |