diff options
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 18 |
1 file changed, 15 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 2364d43aaa..531c8b0791 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging { val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) + try { + doFit(dataset, sc, expTable, bcVocab, bcVocabHash) + } finally { + expTable.destroy(blocking = false) + bcVocab.destroy(blocking = false) + bcVocabHash.destroy(blocking = false) + } + } + + private def doFit[S <: Iterable[String]]( + dataset: RDD[S], sc: SparkContext, + expTable: Broadcast[Array[Float]], + bcVocab: Broadcast[Array[VocabWord]], + bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = { // each partition is a collection of sentences, // will be translated into arrays of Index integer val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => @@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging { bcSyn1Global.destroy(false) } newSentences.unpersist() - expTable.destroy(false) - bcVocab.destroy(false) - bcVocabHash.destroy(false) val wordArray = vocab.map(_.word) new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) |