aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorganonp <ganonp@gmail.com>2014-12-29 15:31:19 -0800
committerXiangrui Meng <meng@databricks.com>2014-12-29 15:31:19 -0800
commit343db392b58fb33a3e4bc6fda1da69aaf686b5a9 (patch)
tree303e41b14039d28601c4ca290e12347cd39b54f9 /mllib
parent6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464 (diff)
downloadspark-343db392b58fb33a3e4bc6fda1da69aaf686b5a9.tar.gz
spark-343db392b58fb33a3e4bc6fda1da69aaf686b5a9.tar.bz2
spark-343db392b58fb33a3e4bc6fda1da69aaf686b5a9.zip
Added setMinCount to Word2Vec.scala
Wanted to customize the private minCount variable in the Word2Vec class. Added a method to do so. Author: ganonp <ganonp@gmail.com> Closes #3693 from ganonp/my-custom-spark and squashes the following commits: ad534f2 [ganonp] made norm method public 5110a6f [ganonp] Reorganized 854958b [ganonp] Fixed Indentation for setMinCount 12ed8f9 [ganonp] Update Word2Vec.scala 76bdf5a [ganonp] Update Word2Vec.scala ffb88bb [ganonp] Update Word2Vec.scala 5eb9100 [ganonp] Added setMinCount to Word2Vec.scala
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala2
2 files changed, 12 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 7960f3cab5..d25a7cd5b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -71,7 +71,8 @@ class Word2Vec extends Serializable with Logging {
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
-
+ private var minCount = 5
+
/**
* Sets vector size (default: 100).
*/
@@ -114,6 +115,15 @@ class Word2Vec extends Serializable with Logging {
this
}
+ /**
+ * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
+ * model's vocabulary (default: 5).
+ */
+ def setMinCount(minCount: Int): this.type = {
+ this.minCount = minCount
+ this
+ }
+
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
@@ -122,9 +132,6 @@ class Word2Vec extends Serializable with Logging {
/** context words from [-window, window] */
private val window = 5
- /** minimum frequency to consider a vocabulary word */
- private val minCount = 5
-
private var trainWordsCount = 0
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 47d1a76fa3..01f3f90577 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -268,7 +268,7 @@ object Vectors {
* @param p norm.
* @return norm in L^p^ space.
*/
- private[spark] def norm(vector: Vector, p: Double): Double = {
+ def norm(vector: Vector, p: Double): Double = {
require(p >= 1.0)
val values = vector match {
case dv: DenseVector => dv.values