diff options
author | yinxusen <yinxusen@gmail.com> | 2016-06-08 09:18:04 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-06-08 09:18:04 +0100 |
commit | 87706eb66cd1370862a1f8ea447484c80969e45f (patch) | |
tree | 5084c7fa634744777aadc3886d48de4d5b2e8839 /mllib/src/main/scala | |
parent | 91fbc880b69bddcf5310afecc49df1102408e1f3 (diff) | |
download | spark-87706eb66cd1370862a1f8ea447484c80969e45f.tar.gz spark-87706eb66cd1370862a1f8ea447484c80969e45f.tar.bz2 spark-87706eb66cd1370862a1f8ea447484c80969e45f.zip |
[SPARK-15793][ML] Add maxSentenceLength for ml.Word2Vec
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-15793
Word2vec in ML package should have maxSentenceLength method for feature parity.
## How was this patch tested?
Tested with Spark unit test.
Author: yinxusen <yinxusen@gmail.com>
Closes #13536 from yinxusen/SPARK-15793.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 2d89eb05a5..33515b2240 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -87,6 +87,21 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getMinCount: Int = $(minCount) + /** + * Sets the maximum length (in words) of each sentence in the input data. + * Any sentence longer than this threshold will be divided into chunks of + * up to `maxSentenceLength` size. + * Default: 1000 + * @group param + */ + final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " + + "(in words) of each sentence in the input data. Any sentence longer than this threshold will " + + "be divided into chunks up to the size.") + setDefault(maxSentenceLength -> 1000) + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + setDefault(stepSize -> 0.025) setDefault(maxIter -> 1) @@ -137,6 +152,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) @@ -149,6 +167,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setSeed($(seed)) .setVectorSize($(vectorSize)) .setWindowSize($(windowSize)) + .setMaxSentenceLength($(maxSentenceLength)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } |