about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/ml/feature.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r-- python/pyspark/ml/feature.py 31
1 files changed, 26 insertions, 5 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 1c423486be..71dc636b83 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -920,7 +920,7 @@ class StandardScalerModel(JavaModel):
@inherit_doc
-class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
+class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
"""
.. note:: Experimental
@@ -943,19 +943,20 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
"""
@keyword_only
- def __init__(self, inputCol=None, outputCol=None):
+ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
"""
- __init__(self, inputCol=None, outputCol=None)
+ __init__(self, inputCol=None, outputCol=None, handleInvalid="error")
"""
super(StringIndexer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
+ self._setDefault(handleInvalid="error")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
- def setParams(self, inputCol=None, outputCol=None):
+ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
"""
- setParams(self, inputCol=None, outputCol=None)
+ setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
Sets params for this StringIndexer.
"""
kwargs = self.setParams._input_kwargs
@@ -1235,6 +1236,10 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
>>> model = indexer.fit(df)
>>> model.transform(df).head().indexed
DenseVector([1.0, 0.0])
+ >>> model.numFeatures
+ 2
+ >>> model.categoryMaps
+ {0: {0.0: 0, -1.0: 1}}
>>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test
DenseVector([0.0, 1.0])
>>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"}
@@ -1297,6 +1302,22 @@ class VectorIndexerModel(JavaModel):
Model fitted by VectorIndexer.
"""
+ @property
+ def numFeatures(self):
+ """
+ Number of features, i.e., length of Vectors which this transforms.
+ """
+ return self._call_java("numFeatures")
+
+ @property
+ def categoryMaps(self):
+ """
+ Feature value index. Keys are categorical feature indices (column indices).
+ Values are maps from original feature values to 0-based category indices.
+ If a feature is not in this map, it is treated as continuous.
+ """
+ return self._call_java("javaCategoryMaps")
+
@inherit_doc
class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol):