aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/feature.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/feature.py')
-rw-r--r--python/pyspark/mllib/feature.py59
1 files changed, 57 insertions, 2 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 1140539a24..aac305db6c 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -33,10 +33,12 @@ from py4j.protocol import Py4JJavaError
from pyspark import SparkContext
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors, _convert_to_vector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.regression import LabeledPoint
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
- 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
+ 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
+ 'ChiSqSelector', 'ChiSqSelectorModel']
class VectorTransformer(object):
@@ -199,6 +201,59 @@ class StandardScaler(object):
return StandardScalerModel(jmodel)
+class ChiSqSelectorModel(JavaVectorTransformer):
+ """
+ .. note:: Experimental
+
+ Represents a Chi Squared selector model.
+ """
+ def transform(self, vector):
+ """
+ Applies transformation on a vector.
+
+ :param vector: Vector or RDD of Vector to be transformed.
+ :return: transformed vector.
+ """
+ return JavaVectorTransformer.transform(self, vector)
+
+
+class ChiSqSelector(object):
+ """
+ .. note:: Experimental
+
+ Creates a ChiSquared feature selector.
+
+ >>> data = [
+ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
+ ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
+ ... LabeledPoint(1.0, [0.0, 9.0, 8.0]),
+ ... LabeledPoint(2.0, [8.0, 9.0, 5.0])
+ ... ]
+ >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
+ >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
+ SparseVector(1, {0: 6.0})
+ >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
+ DenseVector([5.0])
+ """
+ def __init__(self, numTopFeatures):
+ """
+ :param numTopFeatures: number of features that selector will select.
+ """
+ self.numTopFeatures = int(numTopFeatures)
+
+ def fit(self, data):
+ """
+ Returns a ChiSquared feature selector.
+
+ :param data: an `RDD[LabeledPoint]` containing the labeled dataset
+ with categorical features. Real-valued features will be
+ treated as categorical for each distinct value.
+ Apply feature discretizer before using this function.
+ """
+ jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
+ return ChiSqSelectorModel(jmodel)
+
+
class HashingTF(object):
"""
.. note:: Experimental