about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorLiquan Pei <liquanpei@gmail.com>2014-08-12 00:28:00 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-12 00:28:00 -0700
commitf0060b75ff67ab60babf54149a6860edc53cb6e9 (patch)
tree16bb768ddbe7d3b431bf9feb64ce63dff1ffea6a /mllib
parent9038d94e1e50e05de00fd51af4fd7b9280481cdc (diff)
downloadspark-f0060b75ff67ab60babf54149a6860edc53cb6e9.tar.gz
spark-f0060b75ff67ab60babf54149a6860edc53cb6e9.tar.bz2
spark-f0060b75ff67ab60babf54149a6860edc53cb6e9.zip
[MLlib] Correctly set vectorSize and alpha
mengxr Correctly set vectorSize and alpha in Word2Vec training. Author: Liquan Pei <liquanpei@gmail.com> Closes #1900 from Ishiihara/Word2Vec-bugfix and squashes the following commits: 85f64f2 [Liquan Pei] correctly set vectorSize and alpha
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 25
1 file changed, 12 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 395037e1ec..ecd49ea2ff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -119,7 +119,6 @@ class Word2Vec extends Serializable with Logging {
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
- private val layer1Size = vectorSize
/** context words from [-window, window] */
private val window = 5
@@ -131,7 +130,6 @@ class Word2Vec extends Serializable with Logging {
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
private var vocabHash = mutable.HashMap.empty[String, Int]
- private var alpha = startingAlpha
private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
@@ -287,9 +285,10 @@ class Word2Vec extends Serializable with Logging {
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
var syn0Global =
- Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
- var syn1Global = new Array[Float](vocabSize * layer1Size)
+ Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
+ var syn1Global = new Array[Float](vocabSize * vectorSize)
+ var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -317,24 +316,24 @@ class Word2Vec extends Serializable with Logging {
val c = pos - window + a
if (c >= 0 && c < sentence.size) {
val lastWord = sentence(c)
- val l1 = lastWord * layer1Size
- val neu1e = new Array[Float](layer1Size)
+ val l1 = lastWord * vectorSize
+ val neu1e = new Array[Float](vectorSize)
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
- val l2 = bcVocab.value(word).point(d) * layer1Size
+ val l2 = bcVocab.value(word).point(d) * vectorSize
// Propagate hidden -> output
- var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+ var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
- blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
- blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+ blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
+ blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
}
d += 1
}
- blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+ blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
}
}
a += 1
@@ -365,8 +364,8 @@ class Word2Vec extends Serializable with Logging {
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
- val vector = new Array[Float](layer1Size)
- Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
+ val vector = new Array[Float](vectorSize)
+ Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
word2VecMap += word -> vector
i += 1
}