diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-05-08 15:48:39 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-05-08 15:48:39 -0700 |
commit | 35c9599b94de759204ed33cdd46d8ee108bccd86 (patch) | |
tree | 4e2acabba806470d73370105a52682cd35ec0628 /python | |
parent | 6dad76e5eba3c2925bfc9d142f31f7c2dc649886 (diff) | |
download | spark-35c9599b94de759204ed33cdd46d8ee108bccd86.tar.gz spark-35c9599b94de759204ed33cdd46d8ee108bccd86.tar.bz2 spark-35c9599b94de759204ed33cdd46d8ee108bccd86.zip |
[SPARK-5913] [MLLIB] Python API for ChiSqSelector
Add a Python API for mllib.feature.ChiSqSelector
https://issues.apache.org/jira/browse/SPARK-5913
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #5939 from yanboliang/spark-5913 and squashes the following commits:
cdaac99 [Yanbo Liang] Python API for ChiSqSelector
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/feature.py | 59 |
1 files changed, 57 insertions, 2 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 1140539a24..aac305db6c 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -33,10 +33,12 @@ from py4j.protocol import Py4JJavaError from pyspark import SparkContext from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, _convert_to_vector +from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', - 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] + 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', + 'ChiSqSelector', 'ChiSqSelectorModel'] class VectorTransformer(object): @@ -199,6 +201,59 @@ class StandardScaler(object): return StandardScalerModel(jmodel) +class ChiSqSelectorModel(JavaVectorTransformer): + """ + .. note:: Experimental + + Represents a Chi Squared selector model. + """ + def transform(self, vector): + """ + Applies transformation on a vector. + + :param vector: Vector or RDD of Vector to be transformed. + :return: transformed vector. + """ + return JavaVectorTransformer.transform(self, vector) + + +class ChiSqSelector(object): + """ + .. note:: Experimental + + Creates a ChiSquared feature selector. + + >>> data = [ + ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), + ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), + ... LabeledPoint(1.0, [0.0, 9.0, 8.0]), + ... LabeledPoint(2.0, [8.0, 9.0, 5.0]) + ... ] + >>> model = ChiSqSelector(1).fit(sc.parallelize(data)) + >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) + SparseVector(1, {0: 6.0}) + >>> model.transform(DenseVector([8.0, 9.0, 5.0])) + DenseVector([5.0]) + """ + def __init__(self, numTopFeatures): + """ + :param numTopFeatures: number of features that selector will select. + """ + self.numTopFeatures = int(numTopFeatures) + + def fit(self, data): + """ + Returns a ChiSquared feature selector. + + :param data: an `RDD[LabeledPoint]` containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + Apply feature discretizer before using this function. + """ + jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data) + return ChiSqSelectorModel(jmodel) + + class HashingTF(object): """ .. note:: Experimental |