aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala5
1 files changed, 3 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 0a0e0b0960..b96bc48566 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
@@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String)
setDefault(stopWords -> StopWords.English, caseSensitive -> false)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
val t = if ($(caseSensitive)) {
val stopWordsSet = $(stopWords).toSet