author     Yuhao Yang <hhbyyh@gmail.com>    2015-11-09 16:55:23 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2015-11-09 16:55:23 -0800
commit     61f9c8711c79f35d67b0456155866da316b131d9 (patch)
tree       f8a120c315999ba1a459d8b2965b6a7646865df1 /mllib
parent     7dc9d8dba6c4bc655896b137062d896dec4ef64a (diff)
[SPARK-11069][ML] Add RegexTokenizer option to convert to lowercase
jira: https://issues.apache.org/jira/browse/SPARK-11069

quotes from jira:
Tokenizer converts strings to lowercase automatically, but RegexTokenizer does not. It would be nice to add an option to RegexTokenizer to convert to lowercase.
Proposal: call the Boolean Param "toLowercase", set default to false (so behavior does not change).
Actually sklearn converts to lowercase before tokenizing too.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #9092 from hhbyyh/tokenLower.
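As a quick illustration of the new Param, a minimal usage sketch (the sample data and column names are hypothetical, and a Spark 1.x sqlContext is assumed to be in scope; only setToLowercase is new in this patch):

    // Hypothetical usage sketch; not part of the patch itself.
    import org.apache.spark.ml.feature.RegexTokenizer

    val df = sqlContext.createDataFrame(Seq(
      (0L, "Test for tokenization.")
    )).toDF("id", "rawText")

    val tokenizer = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
      .setToLowercase(false)  // keep the original case; the default is true

    // tokens column: ["Test", "for", "tokenization."] with case preserved
    tokenizer.transform(df).select("tokens").show()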
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala        | 19 +++++++++++++++++--
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java |  1 +
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala   | 22 +++++++++++++++++-----
3 files changed, 35 insertions(+), 7 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 248288ca73..1b82b40caa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -100,10 +100,25 @@ class RegexTokenizer(override val uid: String)
/** @group getParam */
def getPattern: String = $(pattern)
- setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+")
+ /**
+ * Indicates whether to convert all characters to lowercase before tokenizing.
+ * Default: true
+ * @group param
+ */
+ final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase",
+ "whether to convert all characters to lowercase before tokenizing.")
+
+ /** @group setParam */
+ def setToLowercase(value: Boolean): this.type = set(toLowercase, value)
+
+ /** @group getParam */
+ def getToLowercase: Boolean = $(toLowercase)
+
+ setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
- override protected def createTransformFunc: String => Seq[String] = { str =>
+ override protected def createTransformFunc: String => Seq[String] = { originStr =>
val re = $(pattern).r
+ val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
val minLength = $(minTokenLength)
tokens.filter(_.length >= minLength)
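The patched transform logic above, restated as a standalone sketch outside the Param machinery (the tokenize name and its default arguments are illustrative, not Spark API):

    // Standalone sketch mirroring the patched createTransformFunc; not part of the diff.
    def tokenize(
        originStr: String,
        pattern: String = "\\s+",
        gaps: Boolean = true,
        minTokenLength: Int = 1,
        toLowercase: Boolean = true): Seq[String] = {
      val re = pattern.r
      // Lowercase first (the new step), then either split on gaps or match tokens.
      val str = if (toLowercase) originStr.toLowerCase() else originStr
      val tokens = if (gaps) re.split(str).toSeq else re.findAllIn(str).toSeq
      tokens.filter(_.length >= minTokenLength)
    }

    tokenize("Test for tokenization.")           // Seq("test", "for", "tokenization.")
    tokenize("JAVA SCALA", toLowercase = false)  // Seq("JAVA", "SCALA")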
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index 02309ce632..c407d98f1b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -53,6 +53,7 @@ public class JavaTokenizerSuite {
.setOutputCol("tokens")
.setPattern("\\s")
.setGaps(true)
+ .setToLowercase(false)
.setMinTokenLength(3);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index e5fd21c3f6..a02992a240 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -48,13 +48,13 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
.setInputCol("rawText")
.setOutputCol("tokens")
val dataset0 = sqlContext.createDataFrame(Seq(
- TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")),
- TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct"))
+ TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
+ TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
))
testRegexTokenizer(tokenizer0, dataset0)
val dataset1 = sqlContext.createDataFrame(Seq(
- TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")),
+ TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Array("punct"))
))
tokenizer0.setMinTokenLength(3)
@@ -64,11 +64,23 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
.setInputCol("rawText")
.setOutputCol("tokens")
val dataset2 = sqlContext.createDataFrame(Seq(
- TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")),
- TokenizerTestData("Te,st. punct", Array("Te,st.", "punct"))
+ TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
+ TokenizerTestData("Te,st. punct", Array("te,st.", "punct"))
))
testRegexTokenizer(tokenizer2, dataset2)
}
+
+ test("RegexTokenizer with toLowercase false") {
+ val tokenizer = new RegexTokenizer()
+ .setInputCol("rawText")
+ .setOutputCol("tokens")
+ .setToLowercase(false)
+ val dataset = sqlContext.createDataFrame(Seq(
+ TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
+ TokenizerTestData("java scala", Array("java", "scala"))
+ ))
+ testRegexTokenizer(tokenizer, dataset)
+ }
}
object RegexTokenizerSuite extends SparkFunSuite {