-rw-r--r--	mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala	50
1 file changed, 35 insertions(+), 15 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index ecd49ea2ff..d2ae62b482 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap

/**
* Entry in vocabulary
@@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging {
var syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn1Global = new Array[Float](vocabSize * vectorSize)
-
var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+ val syn0Modify = new Array[Int](vocabSize)
+ val syn1Modify = new Array[Int](vocabSize)
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
@@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging {
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
- val l2 = bcVocab.value(word).point(d) * vectorSize
+ val inner = bcVocab.value(word).point(d)
+ val l2 = inner * vectorSize
// Propagate hidden -> output
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
@@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging {
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+ syn1Modify(inner) += 1
}
d += 1
}
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+ syn0Modify(lastWord) += 1
}
}
a += 1
@@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging {
}
(syn0, syn1, lwc, wc)
}
- Iterator(model)
+ val syn0Local = model._1
+ val syn1Local = model._2
+ val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
+ var index = 0
+ while (index < vocabSize) {
+ if (syn0Modify(index) != 0) {
+ synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+ }
+ if (syn1Modify(index) != 0) {
+ synOut.update(index + vocabSize,
+ syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+ }
+ index += 1
+ }
+ Iterator(synOut)
}
- val (aggSyn0, aggSyn1, _, _) =
- partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
- val n = syn0_1.length
- val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
- val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
- blas.sscal(n, weight1, syn0_1, 1)
- blas.sscal(n, weight1, syn1_1, 1)
- blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
- blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
- (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+ val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
+ blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+ v1
+ }.collect()
+ var i = 0
+ while (i < synAgg.length) {
+ val index = synAgg(i)._1
+ if (index < vocabSize) {
+ Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
+ } else {
+ Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
}
- syn0Global = aggSyn0
- syn1Global = aggSyn1
+ i += 1
+ }
}
newSentences.unpersist()
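
Below is a minimal, Spark-free sketch of the sparse-aggregation pattern this
diff introduces: each partition emits only the rows it actually modified,
keyed by row index (the commit packs syn1 rows into the same key space by
offsetting them with vocabSize), and a reduceByKey-style merge sums rows that
several partitions touched before copying them back into the dense global
array. Plain Scala collections stand in for the RDD and for
PrimitiveKeyOpenHashMap; the sizes and names below are illustrative only.

// Sketch under the assumptions above -- not the committed implementation.
object SparseAggSketch {
  def main(args: Array[String]): Unit = {
    val vocabSize = 4
    val vectorSize = 3

    // A "partition" returns only the rows it touched, keyed by row index.
    def partitionOutput(touched: Seq[Int], local: Array[Float]): Seq[(Int, Array[Float])] =
      touched.map(row => row -> local.slice(row * vectorSize, (row + 1) * vectorSize))

    val p1 = partitionOutput(Seq(0, 2), Array.fill(vocabSize * vectorSize)(1.0f))
    val p2 = partitionOutput(Seq(2, 3), Array.fill(vocabSize * vectorSize)(2.0f))

    // reduceByKey stand-in: element-wise sum of rows sharing a key, mirroring
    // blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) in the diff.
    val merged: Map[Int, Array[Float]] = (p1 ++ p2).groupBy(_._1).map {
      case (row, vs) =>
        row -> vs.map(_._2).reduce { (v1, v2) =>
          var i = 0
          while (i < vectorSize) { v1(i) += v2(i); i += 1 }
          v1
        }
    }

    // Copy the aggregated rows back into the dense global matrix, as the
    // Array.copy loop at the end of the diff does.
    val syn0Global = new Array[Float](vocabSize * vectorSize)
    merged.foreach { case (row, vec) =>
      Array.copy(vec, 0, syn0Global, row * vectorSize, vectorSize)
    }
    println(syn0Global.grouped(vectorSize).map(_.mkString("[", ", ", "]")).mkString(" "))
  }
}

Note the trade-off the sketch makes visible: only the modified slices travel
over the network, but rows touched by more than one partition are summed,
whereas the old treeReduce averaged the full dense arrays weighted by each
partition's word count.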