aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala6
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index b9e9d56853..cfecae7e0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String)
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def fit(dataset: DataFrame): ChiSqSelectorModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
case Row(label: Double, features: Vector) =>
@@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] (
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
val newField = transformedSchema.last
val selector = udf { chiSqSelector.transform _ }