aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala 7
1 files changed, 6 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a47f27b0af..655ac0bb55 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging {
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = learningRate
+
for (k <- 1 to numIterations) {
+ val bcSyn0Global = sc.broadcast(syn0Global)
+ val bcSyn1Global = sc.broadcast(syn1Global)
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
- val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+ val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
@@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging {
}
i += 1
}
+ bcSyn0Global.unpersist(false)
+ bcSyn1Global.unpersist(false)
}
newSentences.unpersist()