diff options
author | ganonp <ganonp@gmail.com> | 2014-12-29 15:31:19 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-12-29 15:31:19 -0800 |
commit | 343db392b58fb33a3e4bc6fda1da69aaf686b5a9 (patch) | |
tree | 303e41b14039d28601c4ca290e12347cd39b54f9 /mllib/src | |
parent | 6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464 (diff) | |
download | spark-343db392b58fb33a3e4bc6fda1da69aaf686b5a9.tar.gz spark-343db392b58fb33a3e4bc6fda1da69aaf686b5a9.tar.bz2 spark-343db392b58fb33a3e4bc6fda1da69aaf686b5a9.zip |
Added setMinCount to Word2Vec.scala
Wanted to customize the private minCount variable in the Word2Vec class. Added
a method to do so.
Author: ganonp <ganonp@gmail.com>
Closes #3693 from ganonp/my-custom-spark and squashes the following commits:
ad534f2 [ganonp] made norm method public
5110a6f [ganonp] Reorganized
854958b [ganonp] Fixed Indentation for setMinCount
12ed8f9 [ganonp] Update Word2Vec.scala
76bdf5a [ganonp] Update Word2Vec.scala
ffb88bb [ganonp] Update Word2Vec.scala
5eb9100 [ganonp] Added setMinCount to Word2Vec.scala
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 15 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 |
2 files changed, 12 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7960f3cab5..d25a7cd5b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -71,7 +71,8 @@ class Word2Vec extends Serializable with Logging { private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() - + private var minCount = 5 + /** * Sets vector size (default: 100). */ @@ -114,6 +115,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + * model's vocabulary (default: 5). + */ + def setMinCount(minCount: Int): this.type = { + this.minCount = minCount + this + } + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -122,9 +132,6 @@ class Word2Vec extends Serializable with Logging { /** context words from [-window, window] */ private val window = 5 - /** minimum frequency to consider a vocabulary word */ - private val minCount = 5 - private var trainWordsCount = 0 private var vocabSize = 0 private var vocab: Array[VocabWord] = null diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 47d1a76fa3..01f3f90577 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -268,7 +268,7 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. */ - private[spark] def norm(vector: Vector, p: Double): Double = { + def norm(vector: Vector, p: Double): Double = { require(p >= 1.0) val values = vector match { case dv: DenseVector => dv.values |