aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala18
1 files changed, 15 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 2364d43aaa..531c8b0791 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging {
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
+ try {
+ doFit(dataset, sc, expTable, bcVocab, bcVocabHash)
+ } finally {
+ expTable.destroy(blocking = false)
+ bcVocab.destroy(blocking = false)
+ bcVocabHash.destroy(blocking = false)
+ }
+ }
+
+ private def doFit[S <: Iterable[String]](
+ dataset: RDD[S], sc: SparkContext,
+ expTable: Broadcast[Array[Float]],
+ bcVocab: Broadcast[Array[VocabWord]],
+ bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
// each partition is a collection of sentences,
// will be translated into arrays of Index integer
val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
@@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging {
bcSyn1Global.destroy(false)
}
newSentences.unpersist()
- expTable.destroy(false)
- bcVocab.destroy(false)
- bcVocabHash.destroy(false)
val wordArray = vocab.map(_.word)
new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)