aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-09-08 22:13:05 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-09-08 22:13:05 -0700
commit2f6fd5256c6650868916a3eefaa0beb091187cbb (patch)
treeb76233d6ff3d4c8209ca3a01029b5f2438126b3d /python
parent0e2f2163314972f6be18e3453c64314d1bee7bb9 (diff)
downloadspark-2f6fd5256c6650868916a3eefaa0beb091187cbb.tar.gz
spark-2f6fd5256c6650868916a3eefaa0beb091187cbb.tar.bz2
spark-2f6fd5256c6650868916a3eefaa0beb091187cbb.zip
[SPARK-9654] [ML] [PYSPARK] Add IndexToString to PySpark
Adds IndexToString to PySpark. Author: Holden Karau <holden@pigscanfly.ca> Closes #7976 from holdenk/SPARK-9654-add-string-indexer-inverse-in-pyspark.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/feature.py74
-rw-r--r--python/pyspark/ml/wrapper.py3
2 files changed, 72 insertions, 5 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index a7c5b2b62e..8c26cfbd5a 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -27,10 +27,11 @@ from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel',
- 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
- 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
- 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
- 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
+ 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion',
+ 'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
+ 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
+ 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel',
+ 'StopWordsRemover']
@inherit_doc
@@ -934,6 +935,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
... key=lambda x: x[0])
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
+ >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels())
+ >>> itd = inverter.transform(td)
+ >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
+ ... key=lambda x: x[0])
+ [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
"""
@keyword_only
@@ -965,6 +971,66 @@ class StringIndexerModel(JavaModel):
Model fitted by StringIndexer.
"""
+ @property
+ def labels(self):
+ """
+ Ordered list of labels, corresponding to indices to be assigned.
+ """
+ return self._java_obj.labels
+
+
+@inherit_doc
+class IndexToString(JavaTransformer, HasInputCol, HasOutputCol):
+ """
+ .. note:: Experimental
+
+ A :py:class:`Transformer` that maps a column of string indices back to a new column of
+ corresponding string values using either the ML attributes of the input column, or if
+ provided using the labels supplied by the user.
+ All original columns are kept during transformation.
+ See L{StringIndexer} for converting strings into indices.
+ """
+
+ # a placeholder to make the labels show up in generated doc
+ labels = Param(Params._dummy(), "labels",
+ "Optional array of labels to be provided by the user, if not supplied or " +
+ "empty, column metadata is read for labels")
+
+ @keyword_only
+ def __init__(self, inputCol=None, outputCol=None, labels=None):
+ """
+ __init__(self, inputCol=None, outputCol=None, labels=None)
+ """
+ super(IndexToString, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
+ self.uid)
+ self.labels = Param(self, "labels",
+ "Optional array of labels to be provided by the user, if not " +
+ "supplied or empty, column metadata is read for labels")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, inputCol=None, outputCol=None, labels=None):
+ """
+ setParams(self, inputCol=None, outputCol=None, labels=None)
+ Sets params for this IndexToString.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def setLabels(self, value):
+ """
+ Sets the value of :py:attr:`labels`.
+ """
+ self._paramMap[self.labels] = value
+ return self
+
+ def getLabels(self):
+ """
+ Gets the value of :py:attr:`labels` or its default value.
+ """
+ return self.getOrDefault(self.labels)
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 253705bde9..8218c7c5f8 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -136,7 +136,8 @@ class JavaEstimator(Estimator, JavaWrapper):
class JavaTransformer(Transformer, JavaWrapper):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
- implementations.
+ implementations. Subclasses should ensure they have the transformer Java object
+ available as _java_obj.
"""
__metaclass__ = ABCMeta