aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala | 86
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 38
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala | 51
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 27
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala | 2
-rwxr-xr-x python/pyspark/ml/feature.py | 71
-rw-r--r-- python/pyspark/mllib/feature.py | 59
7 files changed, 206 insertions, 128 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 0c6a37bab0..9c131a4185 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.feature.ChiSqSelectorType
+import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.rdd.RDD
@@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params
/**
* Number of features that selector will select (ordered by statistic value descending). If the
* number of features is less than numTopFeatures, then this will select all features.
+ * Only applicable when selectorType = "kbest".
* The default value of numTopFeatures is 50.
+ *
* @group param
*/
final val numTopFeatures = new IntParam(this, "numTopFeatures",
@@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params
/** @group getParam */
def getNumTopFeatures: Int = $(numTopFeatures)
+ /**
+ * Percentile of features that selector will select, ordered by statistics value descending.
+ * Only applicable when selectorType = "percentile".
+ * Default value is 0.1.
+ */
final val percentile = new DoubleParam(this, "percentile",
"Percentile of features that selector will select, ordered by statistics value descending.",
ParamValidators.inRange(0, 1))
@@ -64,8 +71,12 @@ private[feature] trait ChiSqSelectorParams extends Params
/** @group getParam */
def getPercentile: Double = $(percentile)
- final val alpha = new DoubleParam(this, "alpha",
- "The highest p-value for features to be kept.",
+ /**
+ * The highest p-value for features to be kept.
+ * Only applicable when selectorType = "fpr".
+ * Default value is 0.05.
+ */
+ final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.",
ParamValidators.inRange(0, 1))
setDefault(alpha -> 0.05)
@@ -73,29 +84,27 @@ private[feature] trait ChiSqSelectorParams extends Params
def getAlpha: Double = $(alpha)
/**
- * The ChiSqSelector supports KBest, Percentile, FPR selection,
- * which is the same as ChiSqSelectorType defined in MLLIB.
- * when call setNumTopFeatures, the selectorType is set to KBest
- * when call setPercentile, the selectorType is set to Percentile
- * when call setAlpha, the selectorType is set to FPR
+ * The selector type of the ChiSqSelector.
+ * Supported options: "kbest" (default), "percentile" and "fpr".
*/
final val selectorType = new Param[String](this, "selectorType",
- "ChiSqSelector Type: KBest, Percentile, FPR")
- setDefault(selectorType -> ChiSqSelectorType.KBest.toString)
+ "The selector type of the ChisqSelector. " +
+ "Supported options: kbest (default), percentile and fpr.",
+ ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray))
+ setDefault(selectorType -> OldChiSqSelector.KBest)
/** @group getParam */
- def getChiSqSelectorType: String = $(selectorType)
+ def getSelectorType: String = $(selectorType)
}
/**
* Chi-Squared feature selection, which selects categorical features to use for predicting a
* categorical label.
- * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- * `KBest` chooses the `k` top features according to a chi-squared test.
- * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `FPR` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `KBest`, the default number of top features is 50.
- * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
+ * `kbest` chooses the `k` top features according to a chi-squared test.
+ * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `fpr` chooses all features whose false positive rate meets some threshold.
+ * By default, the selection method is `kbest`, the default number of top features is 50.
*/
@Since("1.6.0")
final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@@ -105,23 +114,20 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
def this() = this(Identifiable.randomUID("chiSqSelector"))
/** @group setParam */
+ @Since("2.1.0")
+ def setSelectorType(value: String): this.type = set(selectorType, value)
+
+ /** @group setParam */
@Since("1.6.0")
- def setNumTopFeatures(value: Int): this.type = {
- set(selectorType, ChiSqSelectorType.KBest.toString)
- set(numTopFeatures, value)
- }
+ def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+ /** @group setParam */
@Since("2.1.0")
- def setPercentile(value: Double): this.type = {
- set(selectorType, ChiSqSelectorType.Percentile.toString)
- set(percentile, value)
- }
+ def setPercentile(value: Double): this.type = set(percentile, value)
+ /** @group setParam */
@Since("2.1.0")
- def setAlpha(value: Double): this.type = {
- set(selectorType, ChiSqSelectorType.FPR.toString)
- set(alpha, value)
- }
+ def setAlpha(value: Double): this.type = set(alpha, value)
/** @group setParam */
@Since("1.6.0")
@@ -143,23 +149,23 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
case Row(label: Double, features: Vector) =>
OldLabeledPoint(label, OldVectors.fromML(features))
}
- var selector = new feature.ChiSqSelector()
- ChiSqSelectorType.withName($(selectorType)) match {
- case ChiSqSelectorType.KBest =>
- selector.setNumTopFeatures($(numTopFeatures))
- case ChiSqSelectorType.Percentile =>
- selector.setPercentile($(percentile))
- case ChiSqSelectorType.FPR =>
- selector.setAlpha($(alpha))
- case errorType =>
- throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
- }
+ val selector = new feature.ChiSqSelector()
+ .setSelectorType($(selectorType))
+ .setNumTopFeatures($(numTopFeatures))
+ .setPercentile($(percentile))
+ .setAlpha($(alpha))
val model = selector.fit(input)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
+ val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
+ otherPairs.foreach { case (_, paramName: String) =>
+ if (isSet(getParam(paramName))) {
+ logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
+ }
+ }
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkNumericType(schema, $(labelCol))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 5cffbf0892..904000f50d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -629,35 +629,23 @@ private[python] class PythonMLLibAPI extends Serializable {
}
/**
- * Java stub for ChiSqSelector.fit() when the seletion type is KBest. This stub returns a
+ * Java stub for ChiSqSelector.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
- def fitChiSqSelectorKBest(numTopFeatures: Int,
- data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
- new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd)
- }
-
- /**
- * Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a
- * handle to the Java object instead of the content of the Java object.
- * Extra care needs to be taken in the Python code to ensure it gets freed on
- * exit; see the Py4J documentation.
- */
- def fitChiSqSelectorPercentile(percentile: Double,
- data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
- new ChiSqSelector().setPercentile(percentile).fit(data.rdd)
- }
-
- /**
- * Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a
- * handle to the Java object instead of the content of the Java object.
- * Extra care needs to be taken in the Python code to ensure it gets freed on
- * exit; see the Py4J documentation.
- */
- def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
- new ChiSqSelector().setAlpha(alpha).fit(data.rdd)
+ def fitChiSqSelector(
+ selectorType: String,
+ numTopFeatures: Int,
+ percentile: Double,
+ alpha: Double,
+ data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+ new ChiSqSelector()
+ .setSelectorType(selectorType)
+ .setNumTopFeatures(numTopFeatures)
+ .setPercentile(percentile)
+ .setAlpha(alpha)
+ .fit(data.rdd)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index f68a017184..0f7c6e8bc0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -32,12 +32,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}
-@Since("2.1.0")
-private[spark] object ChiSqSelectorType extends Enumeration {
- type SelectorType = Value
- val KBest, Percentile, FPR = Value
-}
-
/**
* Chi Squared selector model.
*
@@ -166,19 +160,18 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
/**
* Creates a ChiSquared feature selector.
- * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- * `KBest` chooses the `k` top features according to a chi-squared test.
- * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `FPR` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `KBest`, the default number of top features is 50.
- * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
+ * `kbest` chooses the `k` top features according to a chi-squared test.
+ * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `fpr` chooses all features whose false positive rate meets some threshold.
+ * By default, the selection method is `kbest`, the default number of top features is 50.
*/
@Since("1.3.0")
class ChiSqSelector @Since("2.1.0") () extends Serializable {
var numTopFeatures: Int = 50
var percentile: Double = 0.1
var alpha: Double = 0.05
- var selectorType = ChiSqSelectorType.KBest
+ var selectorType = ChiSqSelector.KBest
/**
* The is the same to call this() and setNumTopFeatures(numTopFeatures)
@@ -192,7 +185,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
@Since("1.6.0")
def setNumTopFeatures(value: Int): this.type = {
numTopFeatures = value
- selectorType = ChiSqSelectorType.KBest
this
}
@@ -200,7 +192,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
def setPercentile(value: Double): this.type = {
require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]")
percentile = value
- selectorType = ChiSqSelectorType.Percentile
this
}
@@ -208,12 +199,13 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
def setAlpha(value: Double): this.type = {
require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
alpha = value
- selectorType = ChiSqSelectorType.FPR
this
}
@Since("2.1.0")
- def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = {
+ def setSelectorType(value: String): this.type = {
+ require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value),
+ s"ChiSqSelector Type: $value was not supported.")
selectorType = value
this
}
@@ -230,11 +222,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
val chiSqTestResult = Statistics.chiSqTest(data)
.zipWithIndex.sortBy { case (res, _) => -res.statistic }
val features = selectorType match {
- case ChiSqSelectorType.KBest => chiSqTestResult
+ case ChiSqSelector.KBest => chiSqTestResult
.take(numTopFeatures)
- case ChiSqSelectorType.Percentile => chiSqTestResult
+ case ChiSqSelector.Percentile => chiSqTestResult
.take((chiSqTestResult.length * percentile).toInt)
- case ChiSqSelectorType.FPR => chiSqTestResult
+ case ChiSqSelector.FPR => chiSqTestResult
.filter{ case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
@@ -244,3 +236,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
}
}
+@Since("2.1.0")
+object ChiSqSelector {
+
+ /** String name for `kbest` selector type. */
+ private[spark] val KBest: String = "kbest"
+
+ /** String name for `percentile` selector type. */
+ private[spark] val Percentile: String = "percentile"
+
+ /** String name for `fpr` selector type. */
+ private[spark] val FPR: String = "fpr"
+
+ /** Set of selector type and param pairs that ChiSqSelector supports. */
+ private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures",
+ Percentile -> "percentile", FPR -> "alpha")
+
+ /** Set of selector types that ChiSqSelector supports. */
+ private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index e0293dbc4b..6b56e42002 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
.toDF("label", "data", "preFilteredData")
val selector = new ChiSqSelector()
+ .setSelectorType("kbest")
.setNumTopFeatures(1)
.setFeaturesCol("data")
.setLabelCol("label")
@@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
assert(vec1 ~== vec2 absTol 1e-1)
}
- selector.setPercentile(0.34).fit(df).transform(df)
- .select("filtered", "preFilteredData").collect().foreach {
- case Row(vec1: Vector, vec2: Vector) =>
- assert(vec1 ~== vec2 absTol 1e-1)
- }
+ selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
+ .select("filtered", "preFilteredData").collect().foreach {
+ case Row(vec1: Vector, vec2: Vector) =>
+ assert(vec1 ~== vec2 absTol 1e-1)
+ }
+
+ val preFilteredData2 = Seq(
+ Vectors.dense(8.0, 7.0),
+ Vectors.dense(0.0, 9.0),
+ Vectors.dense(0.0, 9.0),
+ Vectors.dense(8.0, 9.0)
+ )
+ val df2 = sc.parallelize(data.zip(preFilteredData2))
+ .map(x => (x._1.label, x._1.features, x._2))
+ .toDF("label", "data", "preFilteredData")
+
+ selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
+ .select("filtered", "preFilteredData").collect().foreach {
+ case Row(vec1: Vector, vec2: Vector) =>
+ assert(vec1 ~== vec2 absTol 1e-1)
+ }
}
test("ChiSqSelector read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index e181a544f7..ec23a4aa73 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(2.0, Vectors.dense(Array(9.0))))
- val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData)
+ val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
val filteredData = labeledDiscreteData.map { lp =>
LabeledPoint(lp.label, model.transform(lp.features))
}.collect().toSet
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index c45434f1a5..12a13849dc 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2586,39 +2586,68 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
.. versionadded:: 2.0.0
"""
+ selectorType = Param(Params._dummy(), "selectorType",
+ "The selector type of the ChisqSelector. " +
+ "Supported options: kbest (default), percentile and fpr.",
+ typeConverter=TypeConverters.toString)
+
numTopFeatures = \
Param(Params._dummy(), "numTopFeatures",
"Number of features that selector will select, ordered by statistics value " +
"descending. If the number of features is < numTopFeatures, then this will select " +
"all features.", typeConverter=TypeConverters.toInt)
+ percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
+ "will select, ordered by statistics value descending.",
+ typeConverter=TypeConverters.toFloat)
+
+ alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.",
+ typeConverter=TypeConverters.toFloat)
+
@keyword_only
- def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
+ def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+ labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05):
"""
- __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label")
+ __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
+ labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05)
"""
super(ChiSqSelector, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
- self._setDefault(numTopFeatures=50)
+ self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.0.0")
def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
- labelCol="labels"):
+ labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05):
"""
- setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\
- labelCol="labels")
+ setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
+ labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05)
Sets params for this ChiSqSelector.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
+ @since("2.1.0")
+ def setSelectorType(self, value):
+ """
+ Sets the value of :py:attr:`selectorType`.
+ """
+ return self._set(selectorType=value)
+
+ @since("2.1.0")
+ def getSelectorType(self):
+ """
+ Gets the value of selectorType or its default value.
+ """
+ return self.getOrDefault(self.selectorType)
+
@since("2.0.0")
def setNumTopFeatures(self, value):
"""
Sets the value of :py:attr:`numTopFeatures`.
+ Only applicable when selectorType = "kbest".
"""
return self._set(numTopFeatures=value)
@@ -2629,6 +2658,36 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
"""
return self.getOrDefault(self.numTopFeatures)
+ @since("2.1.0")
+ def setPercentile(self, value):
+ """
+ Sets the value of :py:attr:`percentile`.
+ Only applicable when selectorType = "percentile".
+ """
+ return self._set(percentile=value)
+
+ @since("2.1.0")
+ def getPercentile(self):
+ """
+ Gets the value of percentile or its default value.
+ """
+ return self.getOrDefault(self.percentile)
+
+ @since("2.1.0")
+ def setAlpha(self, value):
+ """
+ Sets the value of :py:attr:`alpha`.
+ Only applicable when selectorType = "fpr".
+ """
+ return self._set(alpha=value)
+
+ @since("2.1.0")
+ def getAlpha(self):
+ """
+ Gets the value of alpha or its default value.
+ """
+ return self.getOrDefault(self.alpha)
+
def _create_model(self, java_model):
return ChiSqSelectorModel(java_model)
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 077c11370e..4aea81840a 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -271,22 +271,14 @@ class ChiSqSelectorModel(JavaVectorTransformer):
return JavaVectorTransformer.transform(self, vector)
-class ChiSqSelectorType:
- """
- This class defines the selector types of Chi Square Selector.
- """
- KBest, Percentile, FPR = range(3)
-
-
class ChiSqSelector(object):
"""
Creates a ChiSquared feature selector.
The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- `KBest` chooses the `k` top features according to a chi-squared test.
- `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- `FPR` chooses all features whose false positive rate meets some threshold.
- By default, the selection method is `KBest`, the default number of top features is 50.
- User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ `kbest` chooses the `k` top features according to a chi-squared test.
+ `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ `fpr` chooses all features whose false positive rate meets some threshold.
+ By default, the selection method is `kbest`, the default number of top features is 50.
>>> data = [
... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
@@ -299,7 +291,8 @@ class ChiSqSelector(object):
SparseVector(1, {0: 6.0})
>>> model.transform(DenseVector([8.0, 9.0, 5.0]))
DenseVector([5.0])
- >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data))
+ >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit(
+ ... sc.parallelize(data))
>>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
SparseVector(1, {0: 6.0})
>>> model.transform(DenseVector([8.0, 9.0, 5.0]))
@@ -310,41 +303,52 @@ class ChiSqSelector(object):
... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
... ]
- >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data))
+ >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data))
>>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
DenseVector([4.0])
.. versionadded:: 1.4.0
"""
- def __init__(self, numTopFeatures=50):
+ def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05):
self.numTopFeatures = numTopFeatures
- self.selectorType = ChiSqSelectorType.KBest
+ self.selectorType = selectorType
+ self.percentile = percentile
+ self.alpha = alpha
@since('2.1.0')
def setNumTopFeatures(self, numTopFeatures):
"""
- set numTopFeature for feature selection by number of top features
+ set numTopFeatures for feature selection by number of top features.
+ Only applicable when selectorType = "kbest".
"""
self.numTopFeatures = int(numTopFeatures)
- self.selectorType = ChiSqSelectorType.KBest
return self
@since('2.1.0')
def setPercentile(self, percentile):
"""
- set percentile [0.0, 1.0] for feature selection by percentile
+ set percentile [0.0, 1.0] for feature selection by percentile.
+ Only applicable when selectorType = "percentile".
"""
self.percentile = float(percentile)
- self.selectorType = ChiSqSelectorType.Percentile
return self
@since('2.1.0')
def setAlpha(self, alpha):
"""
- set alpha [0.0, 1.0] for feature selection by FPR
+ set alpha [0.0, 1.0] for feature selection by FPR.
+ Only applicable when selectorType = "fpr".
"""
self.alpha = float(alpha)
- self.selectorType = ChiSqSelectorType.FPR
+ return self
+
+ @since('2.1.0')
+ def setSelectorType(self, selectorType):
+ """
+ set the selector type of the ChiSqSelector.
+ Supported options: "kbest" (default), "percentile" and "fpr".
+ """
+ self.selectorType = str(selectorType)
return self
@since('1.4.0')
@@ -357,15 +361,8 @@ class ChiSqSelector(object):
treated as categorical for each distinct value.
Apply feature discretizer before using this function.
"""
- if self.selectorType == ChiSqSelectorType.KBest:
- jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data)
- elif self.selectorType == ChiSqSelectorType.Percentile:
- jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data)
- elif self.selectorType == ChiSqSelectorType.FPR:
- jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data)
- else:
- raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and"
- " FPR(2), the current value is: %s" % self.selectorType)
+ jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures,
+ self.percentile, self.alpha, data)
return ChiSqSelectorModel(jmodel)