author     Augustin Borsu <augustin@sagacify.com>  2015-03-25 10:16:39 -0700
committer  Xiangrui Meng <meng@databricks.com>     2015-03-25 10:16:39 -0700
commit     982952f4aebb474823dd886dd2b18f4277bd7c30 (patch)
tree       c2871dce7280ce6b7aef158890649b0d99aa573b /mllib/src/test
parent     10c78607b2724f5a64b0cdb966e9c5805f23919b (diff)
[ML][FEATURE] SPARK-5566: RegEx Tokenizer
Added a regex-based tokenizer for ml. Currently the regex is fixed, but if a regex-type parameter could be added to the paramMap, the tokenizer's regex could be tuned during cross-validation. I also wonder what the best way to add a stop-word list would be.

Author: Augustin Borsu <augustin@sagacify.com>
Author: Augustin Borsu <a.borsu@gmail.com>
Author: Augustin Borsu <aborsu@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #4504 from aborsu985/master and squashes the following commits:

716d257 [Augustin Borsu] Merge branch 'mengxr-SPARK-5566'
cb07021 [Augustin Borsu] Merge branch 'SPARK-5566' of git://github.com/mengxr/spark into mengxr-SPARK-5566
5f09434 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
a164800 [Xiangrui Meng] remove tabs
556aa27 [Xiangrui Meng] Merge branch 'aborsu985-master' into SPARK-5566
9651aec [Xiangrui Meng] update test
f96526d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5566
2338da5 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
e88d7b8 [Xiangrui Meng] change pattern to a StringParameter; update tests
148126f [Augustin Borsu] Added return type to public functions
12dddb4 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
daf685e [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
6a85982 [Augustin Borsu] Style corrections
38b95a1 [Augustin Borsu] Added Java unit test for RegexTokenizer
b66313f [Augustin Borsu] Modified the pattern Param so it is compiled when given to the Tokenizer
e262bac [Augustin Borsu] Added unit tests in scala
cd6642e [Augustin Borsu] Changed regex to pattern
132b00b [Augustin Borsu] Changed matching to gaps and removed case folding
201a107 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
cb9c9a7 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
d3ef6d3 [Augustin Borsu] Added doc to RegexTokenizer
9082fc3 [Augustin Borsu] Removed stopwords parameters and updated doc
19f9e53 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
f6a5002 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
7f930bb [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
77ff9ca [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
2e89719 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
196cd7a [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
11ca50f [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
9f8685a [Augustin Borsu] RegexTokenizer
9e07a78 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
9547e9d [Augustin Borsu] RegEx Tokenizer
01cd26f [Augustin Borsu] RegExTokenizer
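
A minimal usage sketch of the RegexTokenizer introduced by this patch, assuming a Spark 1.3-era SQLContext named sqlContext is already in scope; the setter names and parameter values are the ones exercised by the tests below, while the DataFrame setup is illustrative:

import org.apache.spark.ml.feature.RegexTokenizer

// Assumed setup: sqlContext is an existing org.apache.spark.sql.SQLContext.
val tokenizer = new RegexTokenizer()
  .setInputCol("rawText")      // column holding the raw text
  .setOutputCol("tokens")      // column that will hold the token array
  .setPattern("\\s")           // regex; with gaps = true it matches separators
  .setGaps(true)               // split on the pattern instead of matching tokens
  .setMinTokenLength(3)        // drop tokens shorter than 3 characters

val df = sqlContext.createDataFrame(Seq(
  ("Test of tok.", 0)
)).toDF("rawText", "id")

// Appends a "tokens" column; expect Seq("Test", "tok.") for the row above,
// since "of" is shorter than the minimum token length.
tokenizer.transform(df).select("tokens").show()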
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java  71
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala    85
2 files changed, 156 insertions(+), 0 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
new file mode 100644
index 0000000000..3806f65002
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaTokenizerSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
+    jsql = new SQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void regexTokenizer() {
+    RegexTokenizer myRegExTokenizer = new RegexTokenizer()
+      .setInputCol("rawText")
+      .setOutputCol("tokens")
+      .setPattern("\\s")
+      .setGaps(true)
+      .setMinTokenLength(3);
+
+    JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Lists.newArrayList(
+      new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
+      new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
+    ));
+    DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+
+    Row[] pairs = myRegExTokenizer.transform(dataset)
+      .select("tokens", "wantedTokens")
+      .collect();
+
+    for (Row r : pairs) {
+      Assert.assertEquals(r.get(0), r.get(1));
+    }
+  }
+}
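
The gaps parameter flips the role of the pattern: with gaps = true the regex describes the separators (the input is split on it), while with gaps = false it describes the tokens themselves (every match becomes a token). A plain-Scala sketch of the two modes; the gaps = false pattern below is an assumption inferred from the expected outputs in the Scala suite, not taken from this patch:

// Illustrative only: mimics the two tokenization modes with plain regexes.
val text = "Te,st. punct"

// gaps = true: the pattern ("\\s") matches separators, so split on it.
val splitTokens = text.split("\\s").toSeq
// -> Seq("Te,st.", "punct")

// gaps = false: the pattern matches the tokens; here an assumed default
// that keeps letter runs and punctuation runs as separate tokens.
val matchedTokens = "\\p{L}+|[^\\p{L}\\s]+".r.findAllIn(text).toSeq
// -> Seq("Te", ",", "st", ".", "punct")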
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
new file mode 100644
index 0000000000..bf862b912d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+@BeanInfo
+case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) {
+  /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */
+  def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq)
+}
+
+class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
+  import org.apache.spark.ml.feature.RegexTokenizerSuite._
+
+  @transient var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("RegexTokenizer") {
+    val tokenizer = new RegexTokenizer()
+      .setInputCol("rawText")
+      .setOutputCol("tokens")
+
+    val dataset0 = sqlContext.createDataFrame(Seq(
+      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
+      TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
+    ))
+    testRegexTokenizer(tokenizer, dataset0)
+
+    val dataset1 = sqlContext.createDataFrame(Seq(
+      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")),
+      TokenizerTestData("Te,st. punct", Seq("punct"))
+    ))
+
+    tokenizer.setMinTokenLength(3)
+    testRegexTokenizer(tokenizer, dataset1)
+
+    tokenizer
+      .setPattern("\\s")
+      .setGaps(true)
+      .setMinTokenLength(0)
+    val dataset2 = sqlContext.createDataFrame(Seq(
+      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
+      TokenizerTestData("Te,st.  punct", Seq("Te,st.", "", "punct"))
+    ))
+    testRegexTokenizer(tokenizer, dataset2)
+  }
+}
+
+object RegexTokenizerSuite extends FunSuite {
+
+  def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
+    t.transform(dataset)
+      .select("tokens", "wantedTokens")
+      .collect()
+      .foreach {
+        case Row(tokens, wantedTokens) =>
+          assert(tokens === wantedTokens)
+      }
+  }
+}
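
The shared helper compares the produced tokens against the wantedTokens column row by row. The minTokenLength parameter toggled in the suite can be read as a simple post-tokenization length filter; a sketch of that assumed semantics, consistent with dataset1's expectations:

// Assumed semantics of minTokenLength: keep tokens of at least that length.
def filterByMinTokenLength(tokens: Seq[String], minTokenLength: Int): Seq[String] =
  tokens.filter(_.length >= minTokenLength)

filterByMinTokenLength(Seq("Te", ",", "st", ".", "punct"), 3)
// -> Seq("punct"), matching dataset1's expectation for "Te,st. punct"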