about summary refs log tree commit diff
path: root/python/pyspark/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r-- python/pyspark/mllib/feature.py 59
1 files changed, 28 insertions, 31 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 077c11370e..4aea81840a 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -271,22 +271,14 @@ class ChiSqSelectorModel(JavaVectorTransformer):
return JavaVectorTransformer.transform(self, vector)
-class ChiSqSelectorType:
- """
- This class defines the selector types of Chi Square Selector.
- """
- KBest, Percentile, FPR = range(3)
-
-
class ChiSqSelector(object):
"""
Creates a ChiSquared feature selector.
The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- `KBest` chooses the `k` top features according to a chi-squared test.
- `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- `FPR` chooses all features whose false positive rate meets some threshold.
- By default, the selection method is `KBest`, the default number of top features is 50.
- User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ `kbest` chooses the `k` top features according to a chi-squared test.
+ `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ `fpr` chooses all features whose false positive rate meets some threshold.
+ By default, the selection method is `kbest`, the default number of top features is 50.
>>> data = [
... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
@@ -299,7 +291,8 @@ class ChiSqSelector(object):
SparseVector(1, {0: 6.0})
>>> model.transform(DenseVector([8.0, 9.0, 5.0]))
DenseVector([5.0])
- >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data))
+ >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit(
+ ... sc.parallelize(data))
>>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
SparseVector(1, {0: 6.0})
>>> model.transform(DenseVector([8.0, 9.0, 5.0]))
@@ -310,41 +303,52 @@ class ChiSqSelector(object):
... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
... ]
- >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data))
+ >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data))
>>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
DenseVector([4.0])
.. versionadded:: 1.4.0
"""
- def __init__(self, numTopFeatures=50):
+ def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05):
self.numTopFeatures = numTopFeatures
- self.selectorType = ChiSqSelectorType.KBest
+ self.selectorType = selectorType
+ self.percentile = percentile
+ self.alpha = alpha
@since('2.1.0')
def setNumTopFeatures(self, numTopFeatures):
"""
- set numTopFeature for feature selection by number of top features
+ set numTopFeature for feature selection by number of top features.
+ Only applicable when selectorType = "kbest".
"""
self.numTopFeatures = int(numTopFeatures)
- self.selectorType = ChiSqSelectorType.KBest
return self
@since('2.1.0')
def setPercentile(self, percentile):
"""
- set percentile [0.0, 1.0] for feature selection by percentile
+ set percentile [0.0, 1.0] for feature selection by percentile.
+ Only applicable when selectorType = "percentile".
"""
self.percentile = float(percentile)
- self.selectorType = ChiSqSelectorType.Percentile
return self
@since('2.1.0')
def setAlpha(self, alpha):
"""
- set alpha [0.0, 1.0] for feature selection by FPR
+ set alpha [0.0, 1.0] for feature selection by FPR.
+ Only applicable when selectorType = "fpr".
"""
self.alpha = float(alpha)
- self.selectorType = ChiSqSelectorType.FPR
+ return self
+
+ @since('2.1.0')
+ def setSelectorType(self, selectorType):
+ """
+ set the selector type of the ChisqSelector.
+ Supported options: "kbest" (default), "percentile" and "fpr".
+ """
+ self.selectorType = str(selectorType)
return self
@since('1.4.0')
@@ -357,15 +361,8 @@ class ChiSqSelector(object):
treated as categorical for each distinct value.
Apply feature discretizer before using this function.
"""
- if self.selectorType == ChiSqSelectorType.KBest:
- jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data)
- elif self.selectorType == ChiSqSelectorType.Percentile:
- jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data)
- elif self.selectorType == ChiSqSelectorType.FPR:
- jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data)
- else:
- raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and"
- " FPR(2), the current value is: %s" % self.selectorType)
+ jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures,
+ self.percentile, self.alpha, data)
return ChiSqSelectorModel(jmodel)