aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorAnthony Truchet <a.truchet@criteo.com>2017-03-08 11:44:25 +0000
committerSean Owen <sowen@cloudera.com>2017-03-08 11:44:25 +0000
commit9ea201cf6482c9c62c9428759d238063db62d66e (patch)
tree14e8ef949cc4e1be20baed302c43ae2b90245613 /mllib
parent3f9f9180c2e695ad468eb813df5feec41e169531 (diff)
downloadspark-9ea201cf6482c9c62c9428759d238063db62d66e.tar.gz
spark-9ea201cf6482c9c62c9428759d238063db62d66e.tar.bz2
spark-9ea201cf6482c9c62c9428759d238063db62d66e.zip
[SPARK-16440][MLLIB] Ensure broadcasted variables are destroyed even in case of exception
## What changes were proposed in this pull request? Ensure broadcasted variables are destroyed even in case of exception ## How was this patch tested? Word2VecSuite was run locally Author: Anthony Truchet <a.truchet@criteo.com> Closes #14299 from AnthonyTruchet/SPARK-16440.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala18
1 file changed, 15 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 2364d43aaa..531c8b0791 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging {
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
+ try {
+ doFit(dataset, sc, expTable, bcVocab, bcVocabHash)
+ } finally {
+ expTable.destroy(blocking = false)
+ bcVocab.destroy(blocking = false)
+ bcVocabHash.destroy(blocking = false)
+ }
+ }
+
+ private def doFit[S <: Iterable[String]](
+ dataset: RDD[S], sc: SparkContext,
+ expTable: Broadcast[Array[Float]],
+ bcVocab: Broadcast[Array[VocabWord]],
+ bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
// each partition is a collection of sentences,
// will be translated into arrays of Index integer
val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
@@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging {
bcSyn1Global.destroy(false)
}
newSentences.unpersist()
- expTable.destroy(false)
- bcVocab.destroy(false)
- bcVocabHash.destroy(false)
val wordArray = vocab.map(_.word)
new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)