aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml')
-rwxr-xr-x[-rw-r--r--]python/pyspark/ml/feature.py38
-rwxr-xr-x[-rw-r--r--]python/pyspark/ml/tests.py7
2 files changed, 29 insertions, 16 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f21e3062ef..d2989fa4cd 100644..100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1738,28 +1738,23 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
"comparison over the stop words", typeConverter=TypeConverters.toBoolean)
@keyword_only
- def __init__(self, inputCol=None, outputCol=None, stopWords=None,
- caseSensitive=False):
+ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
"""
- __init__(self, inputCol=None, outputCol=None, stopWords=None,\
- caseSensitive=false)
+ __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False)
"""
super(StopWordsRemover, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
self.uid)
- stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
- defaultStopWords = list(stopWordsObj.English())
- self._setDefault(stopWords=defaultStopWords, caseSensitive=False)
+ self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
+ caseSensitive=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
- def setParams(self, inputCol=None, outputCol=None, stopWords=None,
- caseSensitive=False):
+ def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
"""
- setParams(self, inputCol="input", outputCol="output", stopWords=None,\
- caseSensitive=false)
+ setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False)
Sets params for this StopWordsRemover.
"""
kwargs = self.setParams._input_kwargs
@@ -1768,31 +1763,42 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
@since("1.6.0")
def setStopWords(self, value):
"""
- Specify the stopwords to be filtered.
+ Sets the value of :py:attr:`stopWords`.
"""
return self._set(stopWords=value)
@since("1.6.0")
def getStopWords(self):
"""
- Get the stopwords.
+ Gets the value of :py:attr:`stopWords` or its default value.
"""
return self.getOrDefault(self.stopWords)
@since("1.6.0")
def setCaseSensitive(self, value):
"""
- Set whether to do a case sensitive comparison over the stop words
+ Sets the value of :py:attr:`caseSensitive`.
"""
return self._set(caseSensitive=value)
@since("1.6.0")
def getCaseSensitive(self):
"""
- Get whether to do a case sensitive comparison over the stop words.
+ Gets the value of :py:attr:`caseSensitive` or its default value.
"""
return self.getOrDefault(self.caseSensitive)
+ @staticmethod
+ @since("2.0.0")
+ def loadDefaultStopWords(language):
+ """
+ Loads the default stop words for the given language.
+ Supported languages: danish, dutch, english, finnish, french, german, hungarian,
+ italian, norwegian, portuguese, russian, spanish, swedish, turkish
+ """
+ stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWordsRemover
+ return list(stopWordsObj.loadDefaultStopWords(language))
+
@inherit_doc
@ignore_unicode_prefix
@@ -1843,7 +1849,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
@since("1.3.0")
def setParams(self, inputCol=None, outputCol=None):
"""
- setParams(self, inputCol="input", outputCol="output")
+ setParams(self, inputCol=None, outputCol=None)
Sets params for this Tokenizer.
"""
kwargs = self.setParams._input_kwargs
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 78ec96af8a..ad1631fb5b 100644..100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -417,6 +417,13 @@ class FeatureTests(PySparkTestCase):
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, ["a"])
+ # with language selection
+ stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
+ dataset = sqlContext.createDataFrame([Row(input=["acaba", "ama", "biri"])])
+ stopWordRemover.setStopWords(stopwords)
+ self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEqual(transformedDF.head().output, [])
def test_count_vectorizer_with_binary(self):
sqlContext = SQLContext(self.sc)