-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala   6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala   2
-rw-r--r--  python/pyspark/ml/feature.py   73
-rw-r--r--  python/pyspark/ml/tests.py   20
4 files changed, 93 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 5d77ea08db..7da430c7d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
 /**
  * stop words list
  */
-private object StopWords {
+private[spark] object StopWords {
 
   /**
    * Use the same default stopwords list as scikit-learn.
    * The original list can be found from "Glasgow Information Retrieval Group"
    * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
    */
-  val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+  val English = Array( "a", "about", "above", "across", "after", "afterwards", "again",
     "against", "all", "almost", "alone", "along", "already", "also", "although", "always",
     "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
     "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
@@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false)
+  setDefault(stopWords -> StopWords.English, caseSensitive -> false)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
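
The visibility change above, from private to private[spark], is what lets the Python half of this patch reach the Scala object at all: the PySpark wrapper added below fetches the default list from the JVM over the py4j gateway. A minimal sketch of that lookup, assuming an active SparkContext (the _jvm helper is the one this patch imports into feature.py):

    # Sketch: fetch the Scala default stop-word list over py4j.
    # Assumes a SparkContext is already running on this gateway.
    from pyspark.ml.wrapper import _jvm

    stop_words_obj = _jvm().org.apache.spark.ml.feature.StopWords
    default_stop_words = list(stop_words_obj.English())  # Java String[] -> Python list
    print(default_stop_words[:5])  # ['a', 'about', 'above', 'across', 'after']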
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index f01306f89c..e0d433f566 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("StopWordsRemover with additional words") {
-    val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
+    val stopWords = StopWords.English ++ Array("python", "scala")
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 0626281e20..d955307e27 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -22,7 +22,7 @@ if sys.version > '3':
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
 from pyspark.mllib.common import inherit_doc
 from pyspark.mllib.linalg import _convert_to_vector
 
@@ -30,7 +30,7 @@ __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF',
            'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
            'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
            'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
-           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+           'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
 
 
 @inherit_doc
@@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel):
     """
 
 
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
+    """
+    .. note:: Experimental
+
+    A feature transformer that filters out stop words from input.
+    Note: null values from the input array are preserved unless null is explicitly added to stopWords.
+    """
+    # a placeholder to make the stopwords show up in generated doc
+    stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
+    caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
+                          "comparison over the stop words")
+
+    @keyword_only
+    def __init__(self, inputCol=None, outputCol=None, stopWords=None,
+                 caseSensitive=False):
+        """
+        __init__(self, inputCol=None, outputCol=None, stopWords=None,\
+                 caseSensitive=False)
+        """
+        super(StopWordsRemover, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
+                                            self.uid)
+        self.stopWords = Param(self, "stopWords", "The words to be filtered out")
+        self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
+                                   "sensitive comparison over the stop words")
+        stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
+        defaultStopWords = stopWordsObj.English()
+        self._setDefault(stopWords=defaultStopWords)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, stopWords=None,
+                  caseSensitive=False):
+        """
+        setParams(self, inputCol="input", outputCol="output", stopWords=None,\
+                  caseSensitive=False)
+        Sets params for this StopWordsRemover.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setStopWords(self, value):
+        """
+        Specify the stopwords to be filtered.
+        """
+        self._paramMap[self.stopWords] = value
+        return self
+
+    def getStopWords(self):
+        """
+        Get the stopwords.
+        """
+        return self.getOrDefault(self.stopWords)
+
+    def setCaseSensitive(self, value):
+        """
+        Set whether to do a case sensitive comparison over the stop words.
+        """
+        self._paramMap[self.caseSensitive] = value
+        return self
+
+    def getCaseSensitive(self):
+        """
+        Get whether to do a case sensitive comparison over the stop words.
+        """
+        return self.getOrDefault(self.caseSensitive)
+
+
 @inherit_doc
 @ignore_unicode_prefix
 class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
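
With the class in place, StopWordsRemover drives like any other PySpark transformer. A minimal usage sketch, assuming a live SQLContext named sqlContext (mirroring the fixtures in the test added below):

    from pyspark.ml.feature import StopWordsRemover

    df = sqlContext.createDataFrame([(["a", "panda"],)], ["raw"])
    remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
    remover.transform(df).head().filtered   # ['panda'] -- 'a' is a default stop word

    # A custom list replaces the default entirely; matching is
    # case-insensitive unless caseSensitive=True is set.
    remover.setStopWords(["panda"])
    remover.transform(df).head().filtered   # ['a']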
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 60e4237293..b892318f50 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -31,7 +31,7 @@ else:
     import unittest
 
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext
+from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.param import Param, Params
@@ -258,7 +258,7 @@ class FeatureTests(PySparkTestCase):
     def test_ngram(self):
         sqlContext = SQLContext(self.sc)
         dataset = sqlContext.createDataFrame([
-            ([["a", "b", "c", "d", "e"]])], ["input"])
+            Row(input=["a", "b", "c", "d", "e"])])
         ngram0 = NGram(n=4, inputCol="input", outputCol="output")
         self.assertEqual(ngram0.getN(), 4)
         self.assertEqual(ngram0.getInputCol(), "input")
@@ -266,6 +266,22 @@ class FeatureTests(PySparkTestCase):
         transformedDF = ngram0.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
 
+    def test_stopwordsremover(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+        # Default
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["panda"])
+        # Custom
+        stopwords = ["panda"]
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        self.assertEquals(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["a"])
+
 
 class HasInducedError(Params):
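
The new test exercises the default and custom word lists but not the caseSensitive flag; a hedged sketch of what such a check could look like, reusing the same fixtures (the values here are hypothetical, not part of the patch):

    # Hypothetical follow-on check for case-sensitive matching.
    dataset = sqlContext.createDataFrame([Row(input=["The", "the"])])
    remover = StopWordsRemover(inputCol="input", outputCol="output",
                               stopWords=["The"], caseSensitive=True)
    # Only the exact token "The" is removed; lowercase "the" survives.
    self.assertEqual(remover.transform(dataset).head().output, ["the"])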