aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala19
1 files changed, 17 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 248288ca73..1b82b40caa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -100,10 +100,25 @@ class RegexTokenizer(override val uid: String)
/** @group getParam */
def getPattern: String = $(pattern)
- setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+")
+ /**
+ * Indicates whether to convert all characters to lowercase before tokenizing.
+ * Default: true
+ * @group param
+ */
+ final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase",
+ "whether to convert all characters to lowercase before tokenizing.")
+
+ /** @group setParam */
+ def setToLowercase(value: Boolean): this.type = set(toLowercase, value)
+
+ /** @group getParam */
+ def getToLowercase: Boolean = $(toLowercase)
+
+ setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
- override protected def createTransformFunc: String => Seq[String] = { str =>
+ override protected def createTransformFunc: String => Seq[String] = { originStr =>
val re = $(pattern).r
+ val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
val minLength = $(minTokenLength)
tokens.filter(_.length >= minLength)