diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 2d89eb05a5..33515b2240 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -87,6 +87,21 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getMinCount: Int = $(minCount) + /** + * Sets the maximum length (in words) of each sentence in the input data. + * Any sentence longer than this threshold will be divided into chunks of + * up to `maxSentenceLength` size. + * Default: 1000 + * @group param + */ + final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " + + "(in words) of each sentence in the input data. Any sentence longer than this threshold will " + + "be divided into chunks up to the size.") + setDefault(maxSentenceLength -> 1000) + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + setDefault(stepSize -> 0.025) setDefault(maxIter -> 1) @@ -137,6 +152,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) @@ -149,6 +167,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setSeed($(seed)) .setVectorSize($(vectorSize)) .setWindowSize($(windowSize)) + .setMaxSentenceLength($(maxSentenceLength)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } |