aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/feature.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/feature.py')
-rw-r--r--python/pyspark/mllib/feature.py71
1 files changed, 66 insertions, 5 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 5d99644fca..077c11370e 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -271,11 +271,22 @@ class ChiSqSelectorModel(JavaVectorTransformer):
return JavaVectorTransformer.transform(self, vector)
+class ChiSqSelectorType:
+ """
+ Enumerates the selection methods supported by ChiSqSelector: KBest, Percentile and FPR.
+ """
+ KBest, Percentile, FPR = range(3)
+
+
class ChiSqSelector(object):
"""
Creates a ChiSquared feature selector.
-
- :param numTopFeatures: number of features that selector will select.
+ The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
+ `KBest` chooses the `k` top features according to a chi-squared test.
+ `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ `FPR` chooses all features whose p-value is below a threshold (alpha), thus controlling the false positive rate of selection.
+ By default, the selection method is `KBest` and the number of top features is 50.
+ Call setNumTopFeatures, setPercentile or setAlpha to choose the selection method; each setter also switches the selector type.
>>> data = [
... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
@@ -283,16 +294,58 @@ class ChiSqSelector(object):
... LabeledPoint(1.0, [0.0, 9.0, 8.0]),
... LabeledPoint(2.0, [8.0, 9.0, 5.0])
... ]
- >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
+ >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data))
+ >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
+ SparseVector(1, {0: 6.0})
+ >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
+ DenseVector([5.0])
+ >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data))
>>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
SparseVector(1, {0: 6.0})
>>> model.transform(DenseVector([8.0, 9.0, 5.0]))
DenseVector([5.0])
+ >>> data = [
+ ... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})),
+ ... LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})),
+ ... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
+ ... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
+ ... ]
+ >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data))
+ >>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
+ DenseVector([4.0])
.. versionadded:: 1.4.0
"""
- def __init__(self, numTopFeatures):
+ def __init__(self, numTopFeatures=50):
+ self.numTopFeatures = numTopFeatures
+ self.selectorType = ChiSqSelectorType.KBest
+
+ @since('2.1.0')
+ def setNumTopFeatures(self, numTopFeatures):
+ """
+ Set numTopFeatures (converted to int), the number of top features to select, and switch the selector type to KBest.
+ """
self.numTopFeatures = int(numTopFeatures)
+ self.selectorType = ChiSqSelectorType.KBest
+ return self
+
+ @since('2.1.0')
+ def setPercentile(self, percentile):
+ """
+ Set percentile (converted to float, expected in [0.0, 1.0]), the fraction of features to select, and switch the selector type to Percentile.
+ """
+ self.percentile = float(percentile)
+ self.selectorType = ChiSqSelectorType.Percentile
+ return self
+
+ @since('2.1.0')
+ def setAlpha(self, alpha):
+ """
+ Set alpha (converted to float, expected in [0.0, 1.0]), the threshold used for FPR selection, and switch the selector type to FPR.
+ """
+ self.alpha = float(alpha)
+ self.selectorType = ChiSqSelectorType.FPR
+ return self
@since('1.4.0')
def fit(self, data):
@@ -304,7 +357,15 @@ class ChiSqSelector(object):
treated as categorical for each distinct value.
Apply feature discretizer before using this function.
"""
- jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
+ if self.selectorType == ChiSqSelectorType.KBest:
+ jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data)
+ elif self.selectorType == ChiSqSelectorType.Percentile:
+ jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data)
+ elif self.selectorType == ChiSqSelectorType.FPR:
+ jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data)
+ else:
+ raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and"
+ " FPR(2), the current value is: %s" % self.selectorType)
return ChiSqSelectorModel(jmodel)