diff options
author | Holden Karau <holden@pigscanfly.ca> | 2015-09-01 10:48:57 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-09-01 10:48:57 -0700 |
commit | e6e483cc4de740c46398385b03ffe0e662edae39 (patch) | |
tree | 652cf519f902aaaf8eecc564791690b395aea81b /python/pyspark/ml/tests.py | |
parent | 391e6be0ae883f3ea0fab79463eb8b618af79afb (diff) | |
download | spark-e6e483cc4de740c46398385b03ffe0e662edae39.tar.gz spark-e6e483cc4de740c46398385b03ffe0e662edae39.tar.bz2 spark-e6e483cc4de740c46398385b03ffe0e662edae39.zip |
[SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover
Add a Python API for the Stop Words Remover.
Author: Holden Karau <holden@pigscanfly.ca>
Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 20 |
1 files changed, 18 insertions, 2 deletions
```diff
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 60e4237293..b892318f50 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -31,7 +31,7 @@ else:
     import unittest
 
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext
+from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.param import Param, Params
@@ -258,7 +258,7 @@ class FeatureTests(PySparkTestCase):
     def test_ngram(self):
         sqlContext = SQLContext(self.sc)
         dataset = sqlContext.createDataFrame([
-            ([["a", "b", "c", "d", "e"]])], ["input"])
+            Row(input=["a", "b", "c", "d", "e"])])
         ngram0 = NGram(n=4, inputCol="input", outputCol="output")
         self.assertEqual(ngram0.getN(), 4)
         self.assertEqual(ngram0.getInputCol(), "input")
@@ -266,6 +266,22 @@ class FeatureTests(PySparkTestCase):
         transformedDF = ngram0.transform(dataset)
         self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
 
+    def test_stopwordsremover(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+        # Default
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["panda"])
+        # Custom
+        stopwords = ["panda"]
+        stopWordRemover.setStopWords(stopwords)
+        self.assertEquals(stopWordRemover.getInputCol(), "input")
+        self.assertEquals(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEquals(transformedDF.head().output, ["a"])
+
 
 class HasInducedError(Params):
 
```