aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/ml/feature.py | 2
-rw-r--r-- python/pyspark/ml/tests.py | 3
2 files changed, 4 insertions, 1 deletion
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 809a513316..0d8ef1297f 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1765,7 +1765,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
self.uid)
stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
- defaultStopWords = stopWordsObj.English()
+ defaultStopWords = list(stopWordsObj.English())
self._setDefault(stopWords=defaultStopWords, caseSensitive=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 86c0254a2b..85ad949c5a 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -22,6 +22,7 @@ import array
import sys
if sys.version > '3':
xrange = range
+ basestring = str
try:
import xmlrunner
@@ -398,6 +399,8 @@ class FeatureTests(PySparkTestCase):
self.assertEqual(stopWordRemover.getInputCol(), "input")
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, ["panda"])
+ self.assertEqual(type(stopWordRemover.getStopWords()), list)
+ self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring))
# Custom
stopwords = ["panda"]
stopWordRemover.setStopWords(stopwords)