aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/feature.py
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-09-01 10:48:57 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-01 10:48:57 -0700
commite6e483cc4de740c46398385b03ffe0e662edae39 (patch)
tree652cf519f902aaaf8eecc564791690b395aea81b /python/pyspark/ml/feature.py
parent391e6be0ae883f3ea0fab79463eb8b618af79afb (diff)
downloadspark-e6e483cc4de740c46398385b03ffe0e662edae39.tar.gz
spark-e6e483cc4de740c46398385b03ffe0e662edae39.tar.bz2
spark-e6e483cc4de740c46398385b03ffe0e662edae39.zip
[SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover
Add a python API for the Stop Words Remover. Author: Holden Karau <holden@pigscanfly.ca> Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover.
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r--python/pyspark/ml/feature.py73
1 files changed, 71 insertions, 2 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 0626281e20..d955307e27 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -22,7 +22,7 @@ if sys.version > '3':
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
@@ -30,7 +30,7 @@ __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF',
'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
- 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+ 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
@inherit_doc
@@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel):
"""
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
+ """
+ .. note:: Experimental
+
+ A feature transformer that filters out stop words from input.
+ Note: null values from input array are preserved unless adding null to stopWords explicitly.
+ """
+ # a placeholder to make the stopwords show up in generated doc
+ stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
+ caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
+ "comparison over the stop words")
+
+ @keyword_only
+ def __init__(self, inputCol=None, outputCol=None, stopWords=None,
+ caseSensitive=False):
+ """
+ __init__(self, inputCol=None, outputCol=None, stopWords=None,\
+ caseSensitive=False)
+ """
+ super(StopWordsRemover, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
+ self.uid)
+ self.stopWords = Param(self, "stopWords", "The words to be filtered out")
+ self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
+ "sensitive comparison over the stop words")
+ stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
+ defaultStopWords = stopWordsObj.English()
+ self._setDefault(stopWords=defaultStopWords)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, inputCol=None, outputCol=None, stopWords=None,
+ caseSensitive=False):
+ """
+ setParams(self, inputCol=None, outputCol=None, stopWords=None,\
+ caseSensitive=False)
+ Sets params for this StopWordsRemover.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def setStopWords(self, value):
+ """
+ Specify the stopwords to be filtered.
+ """
+ self._paramMap[self.stopWords] = value
+ return self
+
+ def getStopWords(self):
+ """
+ Get the stopwords.
+ """
+ return self.getOrDefault(self.stopWords)
+
+ def setCaseSensitive(self, value):
+ """
+ Set whether to do a case sensitive comparison over the stop words
+ """
+ self._paramMap[self.caseSensitive] = value
+ return self
+
+ def getCaseSensitive(self):
+ """
+ Get whether to do a case sensitive comparison over the stop words.
+ """
+ return self.getOrDefault(self.caseSensitive)
+
+
@inherit_doc
@ignore_unicode_prefix
class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):