From 2f6fd5256c6650868916a3eefaa0beb091187cbb Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 8 Sep 2015 22:13:05 -0700 Subject: [SPARK-9654] [ML] [PYSPARK] Add IndexToString to PySpark Adds IndexToString to PySpark. Author: Holden Karau Closes #7976 from holdenk/SPARK-9654-add-string-indexer-inverse-in-pyspark. --- python/pyspark/ml/feature.py | 74 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 4 deletions(-) (limited to 'python/pyspark/ml/feature.py') diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a7c5b2b62e..8c26cfbd5a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,10 +27,11 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', - 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', - 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover'] + 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', + 'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', + 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', + 'StopWordsRemover'] @inherit_doc @@ -934,6 +935,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> itd = inverter.transform(td) + >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), + ... key=lambda x: x[0]) + [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] """ @keyword_only @@ -965,6 +971,66 @@ class StringIndexerModel(JavaModel): Model fitted by StringIndexer. """ + @property + def labels(self): + """ + Ordered list of labels, corresponding to indices to be assigned. + """ + return self._java_obj.labels + + +@inherit_doc +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A :py:class:`Transformer` that maps a column of string indices back to a new column of + corresponding string values using either the ML attributes of the input column, or if + provided using the labels supplied by the user. + All original columns are kept during transformation. + See L{StringIndexer} for converting strings into indices. + """ + + # a placeholder to make the labels show up in generated doc + labels = Param(Params._dummy(), "labels", + "Optional array of labels to be provided by the user, if not supplied or " + + "empty, column metadata is read for labels") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, labels=None): + """ + __init__(self, inputCol=None, outputCol=None, labels=None) + """ + super(IndexToString, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", + self.uid) + self.labels = Param(self, "labels", + "Optional array of labels to be provided by the user, if not " + + "supplied or empty, column metadata is read for labels") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, labels=None): + """ + setParams(self, inputCol=None, outputCol=None, labels=None) + Sets params for this IndexToString. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setLabels(self, value): + """ + Sets the value of :py:attr:`labels`. + """ + self._paramMap[self.labels] = value + return self + + def getLabels(self): + """ + Gets the value of :py:attr:`labels` or its default value. + """ + return self.getOrDefault(self.labels) class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): -- cgit v1.2.3