aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-09-01 10:48:57 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-01 10:48:57 -0700
commite6e483cc4de740c46398385b03ffe0e662edae39 (patch)
tree652cf519f902aaaf8eecc564791690b395aea81b /python/pyspark/ml/tests.py
parent391e6be0ae883f3ea0fab79463eb8b618af79afb (diff)
downloadspark-e6e483cc4de740c46398385b03ffe0e662edae39.tar.gz
spark-e6e483cc4de740c46398385b03ffe0e662edae39.tar.bz2
spark-e6e483cc4de740c46398385b03ffe0e662edae39.zip
[SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover
Add a python API for the Stop Words Remover. Author: Holden Karau <holden@pigscanfly.ca> Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py20
1 files changed, 18 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 60e4237293..b892318f50 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -31,7 +31,7 @@ else:
import unittest
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext
+from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.param import Param, Params
@@ -258,7 +258,7 @@ class FeatureTests(PySparkTestCase):
def test_ngram(self):
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame([
- ([["a", "b", "c", "d", "e"]])], ["input"])
+ Row(input=["a", "b", "c", "d", "e"])])
ngram0 = NGram(n=4, inputCol="input", outputCol="output")
self.assertEqual(ngram0.getN(), 4)
self.assertEqual(ngram0.getInputCol(), "input")
@@ -266,6 +266,22 @@ class FeatureTests(PySparkTestCase):
transformedDF = ngram0.transform(dataset)
self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
+ def test_stopwordsremover(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+ stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+ # Default
+ self.assertEquals(stopWordRemover.getInputCol(), "input")
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEquals(transformedDF.head().output, ["panda"])
+ # Custom
+ stopwords = ["panda"]
+ stopWordRemover.setStopWords(stopwords)
+ self.assertEquals(stopWordRemover.getInputCol(), "input")
+ self.assertEquals(stopWordRemover.getStopWords(), stopwords)
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEquals(transformedDF.head().output, ["a"])
+
class HasInducedError(Params):