aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala2
2 files changed, 12 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 7960f3cab5..d25a7cd5b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -71,7 +71,8 @@ class Word2Vec extends Serializable with Logging {
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
-
+ private var minCount = 5
+
/**
* Sets vector size (default: 100).
*/
@@ -114,6 +115,15 @@ class Word2Vec extends Serializable with Logging {
this
}
+ /**
+ * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
+ * model's vocabulary (default: 5).
+ */
+ def setMinCount(minCount: Int): this.type = {
+ this.minCount = minCount
+ this
+ }
+
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
@@ -122,9 +132,6 @@ class Word2Vec extends Serializable with Logging {
/** context words from [-window, window] */
private val window = 5
- /** minimum frequency to consider a vocabulary word */
- private val minCount = 5
-
private var trainWordsCount = 0
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 47d1a76fa3..01f3f90577 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -268,7 +268,7 @@ object Vectors {
* @param p norm.
* @return norm in L^p^ space.
*/
- private[spark] def norm(vector: Vector, p: Double): Double = {
+ def norm(vector: Vector, p: Double): Double = {
require(p >= 1.0)
val values = vector match {
case dv: DenseVector => dv.values