aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala69
1 files changed, 66 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 1482eb3d1f..0c6a37bab0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.feature.ChiSqSelectorType
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.rdd.RDD
@@ -54,11 +55,47 @@ private[feature] trait ChiSqSelectorParams extends Params
/** @group getParam */
def getNumTopFeatures: Int = $(numTopFeatures)
+
+ final val percentile = new DoubleParam(this, "percentile",
+ "Percentile of features that selector will select, ordered by statistics value descending.",
+ ParamValidators.inRange(0, 1))
+ setDefault(percentile -> 0.1)
+
+ /** @group getParam */
+ def getPercentile: Double = $(percentile)
+
+ final val alpha = new DoubleParam(this, "alpha",
+ "The highest p-value for features to be kept.",
+ ParamValidators.inRange(0, 1))
+ setDefault(alpha -> 0.05)
+
+ /** @group getParam */
+ def getAlpha: Double = $(alpha)
+
+ /**
+ * The ChiSqSelector supports KBest, Percentile, FPR selection,
+ * which is the same as ChiSqSelectorType defined in MLLIB.
+ * when call setNumTopFeatures, the selectorType is set to KBest
+ * when call setPercentile, the selectorType is set to Percentile
+ * when call setAlpha, the selectorType is set to FPR
+ */
+ final val selectorType = new Param[String](this, "selectorType",
+ "ChiSqSelector Type: KBest, Percentile, FPR")
+ setDefault(selectorType -> ChiSqSelectorType.KBest.toString)
+
+ /** @group getParam */
+ def getChiSqSelectorType: String = $(selectorType)
}
/**
* Chi-Squared feature selection, which selects categorical features to use for predicting a
* categorical label.
+ * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
+ * `KBest` chooses the `k` top features according to a chi-squared test.
+ * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `FPR` chooses all features whose false positive rate meets some threshold.
+ * By default, the selection method is `KBest`, the default number of top features is 50.
+ * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
*/
@Since("1.6.0")
final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@@ -69,7 +106,22 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
/** @group setParam */
@Since("1.6.0")
- def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+ def setNumTopFeatures(value: Int): this.type = {
+ set(selectorType, ChiSqSelectorType.KBest.toString)
+ set(numTopFeatures, value)
+ }
+
+ @Since("2.1.0")
+ def setPercentile(value: Double): this.type = {
+ set(selectorType, ChiSqSelectorType.Percentile.toString)
+ set(percentile, value)
+ }
+
+ @Since("2.1.0")
+ def setAlpha(value: Double): this.type = {
+ set(selectorType, ChiSqSelectorType.FPR.toString)
+ set(alpha, value)
+ }
/** @group setParam */
@Since("1.6.0")
@@ -91,8 +143,19 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
case Row(label: Double, features: Vector) =>
OldLabeledPoint(label, OldVectors.fromML(features))
}
- val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input)
- copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this))
+ var selector = new feature.ChiSqSelector()
+ ChiSqSelectorType.withName($(selectorType)) match {
+ case ChiSqSelectorType.KBest =>
+ selector.setNumTopFeatures($(numTopFeatures))
+ case ChiSqSelectorType.Percentile =>
+ selector.setPercentile($(percentile))
+ case ChiSqSelectorType.FPR =>
+ selector.setAlpha($(alpha))
+ case errorType =>
+ throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
+ }
+ val model = selector.fit(input)
+ copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}
@Since("1.6.0")