about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author yinxusen <yinxusen@gmail.com> 2016-06-08 09:18:04 +0100
committer Sean Owen <sowen@cloudera.com> 2016-06-08 09:18:04 +0100
commit 87706eb66cd1370862a1f8ea447484c80969e45f (patch)
tree 5084c7fa634744777aadc3886d48de4d5b2e8839 /mllib
parent 91fbc880b69bddcf5310afecc49df1102408e1f3 (diff)
download spark-87706eb66cd1370862a1f8ea447484c80969e45f.tar.gz
spark-87706eb66cd1370862a1f8ea447484c80969e45f.tar.bz2
spark-87706eb66cd1370862a1f8ea447484c80969e45f.zip
[SPARK-15793][ML] Add maxSentenceLength for ml.Word2Vec
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-15793

Word2Vec in the ML package should have a maxSentenceLength method for feature parity.

## How was this patch tested?

Tested with Spark unit tests.

Author: yinxusen <yinxusen@gmail.com>

Closes #13536 from yinxusen/SPARK-15793.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 19
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala 1
2 files changed, 20 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 2d89eb05a5..33515b2240 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -87,6 +87,21 @@ private[feature] trait Word2VecBase extends Params
/** @group getParam */
def getMinCount: Int = $(minCount)
+ /**
+ * Sets the maximum length (in words) of each sentence in the input data.
+ * Any sentence longer than this threshold will be divided into chunks of
+ * up to `maxSentenceLength` size.
+ * Default: 1000
+ * @group param
+ */
+ final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " +
+ "(in words) of each sentence in the input data. Any sentence longer than this threshold will " +
+ "be divided into chunks up to the size.")
+ setDefault(maxSentenceLength -> 1000)
+
+ /** @group getParam */
+ def getMaxSentenceLength: Int = $(maxSentenceLength)
+
setDefault(stepSize -> 0.025)
setDefault(maxIter -> 1)
@@ -137,6 +152,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
/** @group setParam */
def setMinCount(value: Int): this.type = set(minCount, value)
+ /** @group setParam */
+ def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): Word2VecModel = {
transformSchema(dataset.schema, logging = true)
@@ -149,6 +167,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
.setSeed($(seed))
.setVectorSize($(vectorSize))
.setWindowSize($(windowSize))
+ .setMaxSentenceLength($(maxSentenceLength))
.fit(input)
copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 280a36f56e..16c74f6785 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -191,6 +191,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setSeed(42L)
.setStepSize(0.01)
.setVectorSize(100)
+ .setMaxSentenceLength(500)
testDefaultReadWrite(t)
}