diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0a0e0b0960..b96bc48566 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructType} @@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String) setDefault(stopWords -> StopWords.English, caseSensitive -> false) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet |