From 8beab68152348c44cf2f89850f792f164b06470d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 26 Jan 2016 11:56:46 -0800 Subject: [SPARK-11923][ML] Python API for ml.feature.ChiSqSelector https://issues.apache.org/jira/browse/SPARK-11923 Author: Xusen Yin Closes #10186 from yinxusen/SPARK-11923. --- python/pyspark/ml/feature.py | 98 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) (limited to 'python') diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f139d81bc4..32f324685a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -33,7 +33,8 @@ __all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', + 'ChiSqSelector', 'ChiSqSelectorModel'] @inherit_doc @@ -2237,6 +2238,101 @@ class RFormulaModel(JavaModel): """ +@inherit_doc +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol): + """ + .. note:: Experimental + + Chi-Squared feature selection, which selects categorical features to use for predicting a + categorical label. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), + ... (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), + ... (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)], + ... ["features", "label"]) + >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") + >>> model = selector.fit(df) + >>> model.transform(df).head().selectedFeatures + DenseVector([1.0]) + >>> model.selectedFeatures + [3] + + .. versionadded:: 2.0.0 + """ + + # a placeholder to make it appear in the generated doc + numTopFeatures = \ + Param(Params._dummy(), "numTopFeatures", + "Number of features that selector will select, ordered by statistics value " + + "descending. If the number of features is < numTopFeatures, then this will select " + + "all features.") + + @keyword_only + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): + """ + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label") + """ + super(ChiSqSelector, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) + self.numTopFeatures = \ + Param(self, "numTopFeatures", + "Number of features that selector will select, ordered by statistics value " + + "descending. If the number of features is < numTopFeatures, then this will " + + "select all features.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="labels"): + """ + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\ + labelCol="labels") + Sets params for this ChiSqSelector. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setNumTopFeatures(self, value): + """ + Sets the value of :py:attr:`numTopFeatures`. + """ + self._paramMap[self.numTopFeatures] = value + return self + + @since("2.0.0") + def getNumTopFeatures(self): + """ + Gets the value of numTopFeatures or its default value. + """ + return self.getOrDefault(self.numTopFeatures) + + def _create_model(self, java_model): + return ChiSqSelectorModel(java_model) + + +class ChiSqSelectorModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by ChiSqSelector. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def selectedFeatures(self): + """ + List of indices to select (filter). Must be ordered asc. + """ + return self._call_java("selectedFeatures") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext -- cgit v1.2.3