about | summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2015-05-08 15:48:39 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-05-08 15:48:39 -0700
commit: 35c9599b94de759204ed33cdd46d8ee108bccd86 (patch)
tree: 4e2acabba806470d73370105a52682cd35ec0628 /python
parent: 6dad76e5eba3c2925bfc9d142f31f7c2dc649886 (diff)
download: spark-35c9599b94de759204ed33cdd46d8ee108bccd86.tar.gz
spark-35c9599b94de759204ed33cdd46d8ee108bccd86.tar.bz2
spark-35c9599b94de759204ed33cdd46d8ee108bccd86.zip
[SPARK-5913] [MLLIB] Python API for ChiSqSelector
Add a Python API for mllib.feature.ChiSqSelector.
See https://issues.apache.org/jira/browse/SPARK-5913

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5939 from yanboliang/spark-5913 and squashes the following commits:

cdaac99 [Yanbo Liang] Python API for ChiSqSelector
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/feature.py 59
1 files changed, 57 insertions, 2 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 1140539a24..aac305db6c 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -33,10 +33,12 @@ from py4j.protocol import Py4JJavaError
from pyspark import SparkContext
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors, _convert_to_vector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.regression import LabeledPoint
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
- 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
+ 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
+ 'ChiSqSelector', 'ChiSqSelectorModel']
class VectorTransformer(object):
@@ -199,6 +201,59 @@ class StandardScaler(object):
return StandardScalerModel(jmodel)
+class ChiSqSelectorModel(JavaVectorTransformer):
+ """
+ .. note:: Experimental
+
+ Represents a Chi Squared selector model, as returned by ChiSqSelector.fit.
+ """
+ def transform(self, vector):
+ """
+ Applies transformation on a vector or an RDD of vectors.
+
+ :param vector: Vector or RDD of Vector to be transformed.
+ :return: result of JavaVectorTransformer.transform: the transformed vector (presumably an RDD of transformed vectors for RDD input).
+ """
+ return JavaVectorTransformer.transform(self, vector)
+
+
+class ChiSqSelector(object):
+ """
+ .. note:: Experimental
+
+ Creates a ChiSquared feature selector; call fit() to obtain a ChiSqSelectorModel.
+
+ >>> data = [
+ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
+ ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
+ ... LabeledPoint(1.0, [0.0, 9.0, 8.0]),
+ ... LabeledPoint(2.0, [8.0, 9.0, 5.0])
+ ... ]
+ >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
+ >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
+ SparseVector(1, {0: 6.0})
+ >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
+ DenseVector([5.0])
+ """
+ def __init__(self, numTopFeatures):
+ """
+ :param numTopFeatures: number of features that selector will select (coerced with int()).
+ """
+ self.numTopFeatures = int(numTopFeatures)
+
+ def fit(self, data):
+ """
+ Fits the selector on the given data and returns a ChiSqSelectorModel.
+
+ :param data: an `RDD[LabeledPoint]` containing the labeled dataset
+ with categorical features. Real-valued features will be
+ treated as categorical for each distinct value.
+ Apply feature discretizer before using this function.
+ """
+ jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
+ return ChiSqSelectorModel(jmodel)
+
+
class HashingTF(object):
"""
.. note:: Experimental