diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala | 69 |
1 files changed, 66 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 1482eb3d1f..0c6a37bab0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature +import org.apache.spark.mllib.feature.ChiSqSelectorType import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.rdd.RDD @@ -54,11 +55,47 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getNumTopFeatures: Int = $(numTopFeatures) + + final val percentile = new DoubleParam(this, "percentile", + "Percentile of features that selector will select, ordered by statistics value descending.", + ParamValidators.inRange(0, 1)) + setDefault(percentile -> 0.1) + + /** @group getParam */ + def getPercentile: Double = $(percentile) + + final val alpha = new DoubleParam(this, "alpha", + "The highest p-value for features to be kept.", + ParamValidators.inRange(0, 1)) + setDefault(alpha -> 0.05) + + /** @group getParam */ + def getAlpha: Double = $(alpha) + + /** + * The ChiSqSelector supports KBest, Percentile, FPR selection, + * which is the same as ChiSqSelectorType defined in MLLIB. + * when call setNumTopFeatures, the selectorType is set to KBest + * when call setPercentile, the selectorType is set to Percentile + * when call setAlpha, the selectorType is set to FPR + */ + final val selectorType = new Param[String](this, "selectorType", + "ChiSqSelector Type: KBest, Percentile, FPR") + setDefault(selectorType -> ChiSqSelectorType.KBest.toString) + + /** @group getParam */ + def getChiSqSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. + * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. + * `KBest` chooses the `k` top features according to a chi-squared test. + * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. + * `FPR` chooses all features whose false positive rate meets some threshold. + * By default, the selection method is `KBest`, the default number of top features is 50. + * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. */ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -69,7 +106,22 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str /** @group setParam */ @Since("1.6.0") - def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) + def setNumTopFeatures(value: Int): this.type = { + set(selectorType, ChiSqSelectorType.KBest.toString) + set(numTopFeatures, value) + } + + @Since("2.1.0") + def setPercentile(value: Double): this.type = { + set(selectorType, ChiSqSelectorType.Percentile.toString) + set(percentile, value) + } + + @Since("2.1.0") + def setAlpha(value: Double): this.type = { + set(selectorType, ChiSqSelectorType.FPR.toString) + set(alpha, value) + } /** @group setParam */ @Since("1.6.0") @@ -91,8 +143,19 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str case Row(label: Double, features: Vector) => OldLabeledPoint(label, OldVectors.fromML(features)) } - val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input) - copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this)) + var selector = new feature.ChiSqSelector() + ChiSqSelectorType.withName($(selectorType)) match { + case ChiSqSelectorType.KBest => + selector.setNumTopFeatures($(numTopFeatures)) + case ChiSqSelectorType.Percentile => + selector.setPercentile($(percentile)) + case ChiSqSelectorType.FPR => + selector.setAlpha($(alpha)) + case errorType => + throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") + } + val model = selector.fit(input) + copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") |