aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Xusen Yin <yinxusen@gmail.com> 2016-01-26 11:56:46 -0800
committer: Joseph K. Bradley <joseph@databricks.com> 2016-01-26 11:56:46 -0800
commit: 8beab68152348c44cf2f89850f792f164b06470d (patch)
tree: cf8ba3b8e97b251f60b1ae581aee5362859a14af
parent: cbd507d69cea24adfb335d8fe26ab5a13c053ffc (diff)
download: spark-8beab68152348c44cf2f89850f792f164b06470d.tar.gz
spark-8beab68152348c44cf2f89850f792f164b06470d.tar.bz2
spark-8beab68152348c44cf2f89850f792f164b06470d.zip
[SPARK-11923][ML] Python API for ml.feature.ChiSqSelector
https://issues.apache.org/jira/browse/SPARK-11923

Author: Xusen Yin <yinxusen@gmail.com>

Closes #10186 from yinxusen/SPARK-11923.
-rw-r--r-- python/pyspark/ml/feature.py 98
1 files changed, 97 insertions, 1 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f139d81bc4..32f324685a 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -33,7 +33,8 @@ __all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel',
'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula',
'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer',
- 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel']
+ 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel',
+ 'ChiSqSelector', 'ChiSqSelectorModel']
@inherit_doc
@@ -2237,6 +2238,101 @@ class RFormulaModel(JavaModel):
"""
+@inherit_doc
+class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
+    """
+    .. note:: Experimental
+
+    Chi-Squared feature selection, which selects categorical features to use for predicting a
+    categorical label.
+
+    Fitting this estimator produces a :py:class:`ChiSqSelectorModel`.
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame(
+    ...    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
+    ...     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
+    ...     (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
+    ...    ["features", "label"])
+    >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
+    >>> model = selector.fit(df)
+    >>> model.transform(df).head().selectedFeatures
+    DenseVector([1.0])
+    >>> model.selectedFeatures
+    [3]
+
+    .. versionadded:: 2.0.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    numTopFeatures = \
+        Param(Params._dummy(), "numTopFeatures",
+              "Number of features that selector will select, ordered by statistics value " +
+              "descending. If the number of features is < numTopFeatures, then this will select " +
+              "all features.")
+
+    @keyword_only
+    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                 labelCol="label"):
+        """
+        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
+                 labelCol="label")
+        """
+        super(ChiSqSelector, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
+        # Replace the class-level placeholder with a Param whose parent is this instance.
+        self.numTopFeatures = \
+            Param(self, "numTopFeatures",
+                  "Number of features that selector will select, ordered by statistics value " +
+                  "descending. If the number of features is < numTopFeatures, then this will " +
+                  "select all features.")
+        # Register the documented default on the Python side so that
+        # getNumTopFeatures() has a value to return when the caller did not
+        # pass numTopFeatures explicitly.
+        self._setDefault(numTopFeatures=50)
+        # Forward the keyword arguments captured by @keyword_only to setParams.
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("2.0.0")
+    def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                  labelCol="label"):
+        """
+        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\
+                  labelCol="label")
+        Sets params for this ChiSqSelector.
+        """
+        # The advertised defaults mirror __init__; the values actually applied
+        # are the keyword arguments captured by @keyword_only.
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.0.0")
+    def setNumTopFeatures(self, value):
+        """
+        Sets the value of :py:attr:`numTopFeatures`.
+        """
+        self._paramMap[self.numTopFeatures] = value
+        return self
+
+    @since("2.0.0")
+    def getNumTopFeatures(self):
+        """
+        Gets the value of numTopFeatures or its default value.
+        """
+        return self.getOrDefault(self.numTopFeatures)
+
+    def _create_model(self, java_model):
+        # Wrap the fitted Java model in its Python counterpart.
+        return ChiSqSelectorModel(java_model)
+
+
+class ChiSqSelectorModel(JavaModel):
+    """
+    .. note:: Experimental
+
+    Model produced by fitting a ChiSqSelector.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @property
+    @since("2.0.0")
+    def selectedFeatures(self):
+        """
+        Indices of the features this model selects (filters), in ascending order.
+        """
+        # Read the value from the wrapped Java model.
+        indices = self._call_java("selectedFeatures")
+        return indices
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext