aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2015-08-01 02:31:28 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-01 02:31:28 -0700
commit8765665015ef47a23e00f7d01d4d280c31bb236d (patch)
treeec5107a65819712061aa4c3337f79367bc088519 /mllib/src/test
parentd2a9b66f6c0de89d6d16370af1c77c7f51b11d3e (diff)
downloadspark-8765665015ef47a23e00f7d01d4d280c31bb236d.tar.gz
spark-8765665015ef47a23e00f7d01d4d280c31bb236d.tar.bz2
spark-8765665015ef47a23e00f7d01d4d280c31bb236d.zip
[SPARK-8169] [ML] Add StopWordsRemover as a transformer
jira: https://issues.apache.org/jira/browse/SPARK-8169 stop words: http://en.wikipedia.org/wiki/Stop_words StopWordsRemover takes a string array column and outputs a string array column with all defined stop words removed. The transformer should also come with a standard set of stop words as default. Currently I used a minimum stop words set since on some [case](http://nlp.stanford.edu/IR-book/html/htmledition/dropping-common-terms-stop-words-1.html), small set of stop words is preferred. ASCII char has been tested, Yet I cannot check it in due to style check. Further thought, 1. Maybe I should use OpenHashSet. Is it recommended? 2. Currently I leave the null in input array untouched, i.e. Array(null, null) => Array(null, null). 3. If the current stop words set looks too limited, any suggestion for replacement? We can have something similar to the one in [SKlearn](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/feature_extraction/stop_words.py). Author: Yuhao Yang <hhbyyh@gmail.com> Closes #6742 from hhbyyh/stopwords and squashes the following commits: fa959d8 [Yuhao Yang] separating udf f190217 [Yuhao Yang] replace default list and other small fix 04403ab [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into stopwords b3aa957 [Yuhao Yang] add stopWordsRemover
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala80
1 files changed, 80 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
new file mode 100644
index 0000000000..f01306f89c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+object StopWordsRemoverSuite extends SparkFunSuite {
+ def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
+ t.transform(dataset)
+ .select("filtered", "expected")
+ .collect()
+ .foreach { case Row(tokens, wantedTokens) =>
+ assert(tokens === wantedTokens)
+ }
+ }
+}
+
+class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import StopWordsRemoverSuite._
+
+ test("StopWordsRemover default") {
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("test", "test"), Seq("test", "test")),
+ (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
+ (Seq("a", "the", "an"), Seq()),
+ (Seq("A", "The", "AN"), Seq()),
+ (Seq(null), Seq(null)),
+ (Seq(), Seq())
+ )).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+
+ test("StopWordsRemover case sensitive") {
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ .setCaseSensitive(true)
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("A"), Seq("A")),
+ (Seq("The", "the"), Seq("The"))
+ )).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+
+ test("StopWordsRemover with additional words") {
+ val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ .setStopWords(stopWords)
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("python", "scala", "a"), Seq()),
+ (Seq("Python", "Scala", "swift"), Seq("swift"))
+ )).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+}